#!/usr/bin/env python3

# Copyright (C) SchedMD LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Tuple, Optional
import argparse
import base64
import collections
import hashlib
import importlib.util
import inspect
import json
import logging
import logging.config
import math
import os
import re
import shelve
import shlex
import shutil
import socket
import subprocess
import sys
import tempfile
from enum import Enum
from collections import defaultdict, namedtuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from functools import lru_cache, reduce, wraps
from itertools import chain, compress, islice
from pathlib import Path
from time import sleep, time

import slurm_gcp_plugins

# Check that every third-party module this script needs can be found before
# the imports below run, so a missing dependency produces a readable message
# instead of an ImportError traceback. The TPU client is a soft requirement:
# when it is missing only TPU support is disabled (can_tpu = False).
required_modules = [
    ("googleapiclient", "google-api-python-client"),
    ("requests", "requests"),
    ("yaml", "yaml"),
    ("addict", "addict"),
    ("httplib2", "httplib2"),
    ("google.cloud.tpu_v2", "google-cloud-tpu"),
]
missing_imports = False
can_tpu = True
for module, name in required_modules:
    try:
        # find_spec imports the parent packages of a dotted name and raises
        # ModuleNotFoundError (rather than returning None) if one is missing,
        # e.g. when no `google.cloud` package is installed at all.
        _spec = importlib.util.find_spec(module)
    except ModuleNotFoundError:
        _spec = None
    if _spec is None:
        if module == "google.cloud.tpu_v2":
            can_tpu = False
            print(
                f"WARNING: Missing Python module '{module} (pip:{name})', TPU support will not work."
            )
        else:
            missing_imports = True
            print(f"ERROR: Missing Python module '{module} (pip:{name})'")
if missing_imports:
    print("Aborting due to missing Python modules")
    # sys.exit rather than the site-provided exit(), which is absent under `python -S`
    sys.exit(1)

import google.auth  # noqa: E402
from google.oauth2 import service_account  # noqa: E402
import googleapiclient.discovery  # noqa: E402
import google_auth_httplib2  # noqa: E402
from googleapiclient.http import set_user_agent  # noqa: E402
from google.api_core.client_options import ClientOptions  # noqa: E402
import httplib2  # noqa: E402

if can_tpu:
    from google.cloud import tpu_v2 as tpu  # noqa: E402
import google.api_core.exceptions as gExceptions  # noqa: E402

from requests import get as get_url  # noqa: E402
from requests.exceptions import RequestException  # noqa: E402

import yaml  # noqa: E402
from addict import Dict as NSDict  # noqa: E402

# Optional dependencies: only warn when they are missing.
optional_modules = [
    ("google.cloud.secretmanager", "google-cloud-secret-manager"),
]
for module, name in optional_modules:
    try:
        # find_spec imports the parent packages of a dotted name and raises
        # ModuleNotFoundError (rather than returning None) if one is missing.
        _spec = importlib.util.find_spec(module)
    except ModuleNotFoundError:
        _spec = None
    if _spec is None:
        print(f"WARNING: Missing Python module '{module}' (pip:{name}) ")

USER_AGENT = "Slurm_GCP_Scripts/1.5 (GPN:SchedMD)"
# SLURM_CONFIG_YAML overrides the default config.yaml next to this file
ENV_CONFIG_YAML = os.getenv("SLURM_CONFIG_YAML")
if ENV_CONFIG_YAML:
    CONFIG_FILE = Path(ENV_CONFIG_YAML)
else:
    CONFIG_FILE = Path(__file__).with_name("config.yaml")
# default chunk size used by chunked()
API_REQ_LIMIT = 2000
URI_REGEX = r"[a-z]([-a-z0-9]*[a-z0-9])?"


def mkdirp(path: Path) -> None:
    """Create directory ``path`` and any missing parents; no-op if it exists."""
    path.mkdir(exist_ok=True, parents=True)


# first existing candidate: the directory holding this file, else /slurm/scripts
scripts_dir = next(
    candidate
    for candidate in (Path(__file__).parent, Path("/slurm/scripts"))
    if candidate.is_dir()
)

# readily available compute api handle
compute = None
# slurm-gcp config object, could be empty if not available
cfg = NSDict()
# caching Lookup object
lkp = None

# well-known cluster directories, as Paths in an attribute-access namespace
dirs = NSDict(
    {
        key: Path(location)
        for key, location in {
            "home": "/home",
            "apps": "/opt/apps",
            "slurm": "/slurm",
            "scripts": scripts_dir,
            "custom_scripts": "/slurm/custom_scripts",
            "munge": "/etc/munge",
            "secdisk": "/mnt/disks/sec",
            "log": "/var/log/slurm",
        }.items()
    }
)

# Slurm installation directories, as Paths
slurmdirs = NSDict(
    {
        key: Path(location)
        for key, location in {
            "prefix": "/usr/local",
            "etc": "/usr/local/etc/slurm",
            "state": "/var/spool/slurm",
        }.items()
    }
)


def _represent_as_str(dumper, data):
    """Fallback SafeDumper representer: emit any unregistered object as str(data)."""
    return yaml.representer.SafeRepresenter.represent_str(dumper, str(data))


# the `None` key is PyYAML's catch-all entry for types with no specific representer
yaml.SafeDumper.yaml_representers[None] = _represent_as_str


class ApiEndpoint(Enum):
    """Google Cloud APIs whose client endpoint can be customized.

    Each member's value is the key looked up in ``lkp.endpoint_versions`` by
    ``endpoint_version`` and the host prefix that ``create_client_options``
    uses to build ``https://<value>.<universe_domain>/<version>/``.
    """

    COMPUTE = "compute"
    BQ = "bq"  # presumably BigQuery; not referenced in the visible code
    STORAGE = "storage"
    TPU = "tpu"
    SECRET = "secret_manager"


@lru_cache(maxsize=1)
def default_credentials():
    """Return the Application Default Credentials (computed once, then cached)."""
    credentials, _project = google.auth.default()
    return credentials


@lru_cache(maxsize=1)
def authentication_project():
    """Return the project id that accompanies the Application Default Credentials (cached)."""
    _credentials, project_id = google.auth.default()
    return project_id


DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"


def universe_domain() -> str:
    """Return the instance's ``attributes/universe_domain`` metadata value.

    Any failure to read it falls back to ``DEFAULT_UNIVERSE_DOMAIN``.
    """
    try:
        domain = instance_metadata("attributes/universe_domain")
    except Exception:
        domain = DEFAULT_UNIVERSE_DOMAIN
    return domain


def endpoint_version(api: ApiEndpoint) -> Optional[str]:
    """Return the configured endpoint version for ``api``, or None.

    None is returned when ``api`` is falsy or has no entry in
    ``lkp.endpoint_versions``.
    """
    if not api:
        return None
    versions = lkp.endpoint_versions
    return versions[api.value] if api.value in versions else None


@lru_cache(maxsize=1)
def get_credentials() -> Optional[service_account.Credentials]:
    """Get credentials for service account

    Uses the key file named by GOOGLE_APPLICATION_CREDENTIALS (scoped to
    cloud-platform in the current universe) when set; otherwise the
    Application Default Credentials. The result is cached.
    """
    key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if key_path is None:
        return default_credentials()
    scope = f"https://www.{universe_domain()}/auth/cloud-platform"
    return service_account.Credentials.from_service_account_file(
        key_path, scopes=[scope]
    )


def create_client_options(api: ApiEndpoint = None) -> ClientOptions:
    """Create client options for cloud endpoints

    Args:
        api: the API the client targets. This can be an ``ApiEndpoint`` member
            or its string value (e.g. ``"storage"``, converted with
            ``ApiEndpoint(api)``, which raises ValueError if unknown). It may
            also be None, meaning no API-specific endpoint.
    Returns:
        ``ClientOptions`` with ``universe_domain`` set when the instance
        reports a non-default universe. ``api_endpoint`` is set to
        ``https://<api>.<universe_domain>/<version>/`` when an endpoint
        version is configured for ``api``.
    """
    if isinstance(api, str):
        # accept the enum's string value; endpoint_version needs `.value`
        api = ApiEndpoint(api)
    ver = endpoint_version(api)
    ud = universe_domain()
    options = {}
    if ud and ud != DEFAULT_UNIVERSE_DOMAIN:
        options["universe_domain"] = ud
    if ver:
        # endpoint_version only returns a version for a truthy api
        options["api_endpoint"] = f"https://{api.value}.{ud}/{ver}/"
    co = ClientOptions(**options)
    # api defaults to None: don't dereference `.value` unconditionally
    api_name = api.value if api is not None else None
    log.debug(f"Using ClientOptions = {co} for API: {api_name}")
    return co


class LogFormatter(logging.Formatter):
    """Formatter that tags the level name with a record's ``flag`` attribute
    (``INFO(flag)``) and, for records below INFO, prefixes the message with
    the emitting function's name."""

    def format(self, record):
        # always start from the configured format so changes don't accumulate
        fmt = self._fmt
        record_flag = getattr(record, "flag", None)
        if record_flag is not None and "%(levelname)s" in fmt:
            head, marker, tail = fmt.partition("%(levelname)s")
            fmt = head + marker + "(%(flag)s)" + tail
        if record.levelno < logging.INFO:
            head, marker, tail = fmt.partition("%(message)s")
            fmt = head + "%(funcName)s: " + marker + tail
        self._style._fmt = fmt
        return super().format(record)


class FlagLogAdapter(logging.LoggerAdapter):
    """Logger adapter that stamps every record with ``flag`` so that
    ``filter_logging_flags`` can drop records whose flag is not enabled."""

    def __init__(self, logger, flag, extra=None):
        self.flag = flag
        super().__init__(logger, {} if extra is None else extra)

    @property
    def enabled(self):
        """Whether this adapter's flag is switched on in cfg.extra_logging_flags."""
        return cfg.extra_logging_flags.get(self.flag, False)

    def process(self, msg, kwargs):
        record_extra = kwargs.setdefault("extra", {})
        record_extra.update(self.extra)
        record_extra["flag"] = self.flag
        return msg, kwargs


logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log = logging.getLogger(__name__)
# names of the optional debug channels; each gets a FlagLogAdapter below
logging_flags = [
    "trace_api",
    "subproc",
    "hostlists",
]
log_trace_api, log_subproc, log_hostlists = (
    FlagLogAdapter(log, flag) for flag in logging_flags
)


def access_secret_version(project_id, secret_id, version_id="latest"):
    """
    Return the UTF-8 payload of a Secret Manager secret version, or None if it
    does not exist. ``version_id`` may be a number as a string (e.g. "5") or
    an alias (e.g. "latest").
    """
    from google.cloud import secretmanager
    from google.api_core import exceptions

    client = secretmanager.SecretManagerServiceClient(
        client_options=create_client_options(ApiEndpoint.SECRET)
    )
    secret_path = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
    try:
        response = client.access_secret_version(request={"name": secret_path})
    except exceptions.NotFound:
        log.debug(f"Secret '{secret_path}' was not found!")
        return None
    log.debug(f"Secret '{secret_path}' was found.")
    return response.payload.data.decode("UTF-8")


def parse_self_link(self_link: str):
    """Turn a selfLink URL into a namespace of its collection/value pairs.

    Every ``<key>s/<value>`` segment becomes ``key -> value``, e.g.
    ``https://.../v1/projects/<project>/regions/<region>/...`` gives
    ``{'project': <project>, 'region': <region>, ...}`` (zone, instance,
    image, etc. work the same way).
    """
    pairs = re.findall(r"(?P<key>[^\/\s]+)s\/(?P<value>[^\s\/]+)", self_link)
    return NSDict(pairs)


def parse_bucket_uri(uri: str):
    """
    Parse a bucket url into ``(bucket, path)``.
    E.g. gs://<bucket_name>/<path> -> ("<bucket_name>", "<path>")

    The pattern is matched at the start of ``uri`` only, so anything after the
    last valid path segment (such as a trailing "/") is ignored.

    Raises:
        ValueError: if ``uri`` is not of the form gs://<bucket>/<path> with a
            non-empty path.
    """
    pattern = re.compile(r"gs://(?P<bucket>[^/\s]+)/(?P<path>([^/\s]+)(/[^/\s]+)*)")
    matches = pattern.match(uri)
    if matches is None:
        # previously this surfaced as an opaque AttributeError on None
        raise ValueError(f"not a valid gs://<bucket>/<path> uri: '{uri}'")
    return matches.group("bucket"), matches.group("path")


def trim_self_link(link: str):
    """Return the final path component of a self link URL, e.g.
    https://.../v1/projects/<project>/regions/<region> -> <region>

    Raises a plain Exception when ``link`` contains no '/'.
    """
    _head, separator, resource = link.rpartition("/")
    if not separator:
        raise Exception(f"'/' not found, not a self link: '{link}' ")
    return resource


def execute_with_futures(func, seq):
    """Call ``func`` on each item of ``seq`` in a thread pool.

    Results are discarded. The first exception observed in completion order
    is re-raised. The pool still waits for the remaining tasks on exit.
    """
    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(func, item) for item in seq]
        for finished in as_completed(pending):
            err = finished.exception()
            if err is not None:
                raise err


def map_with_futures(func, seq):
    """Lazily yield ``func(item)`` for each item of ``seq``, in input order.

    All calls are submitted to a thread pool up front. A call that raised
    yields its exception object instead of a result.
    """
    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(func, item) for item in seq]
        for fut in pending:
            try:
                outcome = fut.result()
            except Exception as err:
                outcome = err
            yield outcome


def blob_get(file, project=None):
    """Return a Blob handle for ``<slurm_bucket_path>/<file>``.

    The bucket path comes from the ``attributes/slurm_bucket_path`` instance
    metadata. ``project`` defaults to ``lkp.project``.
    """
    from google.cloud import storage

    client_project = lkp.project if project is None else project
    bucket_name, base_path = parse_bucket_uri(
        instance_metadata("attributes/slurm_bucket_path")
    )
    client = storage.Client(
        project=client_project,
        client_options=create_client_options(ApiEndpoint.STORAGE),
    )
    return client.get_bucket(bucket_name).blob(f"{base_path}/{file}")


def blob_list(prefix="", delimiter=None, project=None):
    """List blobs under ``<slurm_bucket_path>/<prefix>`` as a list.

    ``project`` defaults to ``lkp.project``. ``delimiter`` is passed through
    to ``list_blobs``.
    """
    from google.cloud import storage

    client_project = lkp.project if project is None else project
    bucket_name, base_path = parse_bucket_uri(
        instance_metadata("attributes/slurm_bucket_path")
    )
    client = storage.Client(
        project=client_project,
        client_options=create_client_options(ApiEndpoint.STORAGE),
    )
    # list_blobs returns a lazy iterator; materialize it so requests happen here
    return list(
        client.list_blobs(
            bucket_name, prefix=f"{base_path}/{prefix}", delimiter=delimiter
        )
    )


def _hash_file(fullpath):
    with open(fullpath, "rb") as f:
        file_hash = hashlib.md5()
        chunk = f.read(8192)
        while chunk:
            file_hash.update(chunk)
            chunk = f.read(8192)
    return base64.b64encode(file_hash.digest()).decode("utf-8")


def install_custom_scripts(check_hash=False):
    """download custom scripts from gcs bucket

    Which blob prefixes are fetched depends on ``lkp.instance_role``:

    - login: ``slurm-login-script``
    - compute: ``slurm-{compute,prolog,epilog}-script``, plus
      ``slurm-nodeset-<nodeset>-script`` when the nodeset lookup succeeds
    - controller: ``slurm-{controller,prolog,epilog}-script``
    - any other role: nothing

    A blob named ``slurm-<path>-script-<name>`` is installed under
    ``dirs.custom_scripts``. ``<path>`` is split on ``-`` and its first part
    gets a ``.d`` suffix (e.g. ``nodeset-foo`` -> ``nodeset.d/foo``). The last
    ``_`` in ``<name>`` becomes the extension dot (``foo_sh`` -> ``foo.sh``).
    Installed files are chowned to slurm with mode 0o755.

    Args:
        check_hash: if True, an existing local file is only re-downloaded when
            its base64 MD5 differs from the blob's ``md5_hash``.
    """

    compute_tokens = ["compute", "prolog", "epilog"]
    if lkp.instance_role == "compute":
        try:
            compute_tokens.append(f"nodeset-{lkp.node_nodeset_name()}")
        except Exception as e:
            # best effort: still install the generic compute scripts
            log.error(f"Failed to lookup nodeset: {e}")

    # role -> script-name tokens; unknown roles get no scripts
    prefix_tokens = dict.get(
        {
            "login": ["login"],
            "compute": compute_tokens,
            "controller": ["controller", "prolog", "epilog"],
        },
        lkp.instance_role,
        [],
    )
    prefixes = [f"slurm-{tok}-script" for tok in prefix_tokens]
    blobs = list(chain.from_iterable(blob_list(prefix=p) for p in prefixes))

    script_pattern = re.compile(r"slurm-(?P<path>\S+)-script-(?P<name>\S+)")
    for blob in blobs:
        m = script_pattern.match(Path(blob.name).name)
        if not m:
            log.warning(f"found blob that doesn't match expected pattern: {blob.name}")
            continue
        path_parts = m["path"].split("-")
        path_parts[0] += ".d"
        # NOTE(review): a <name> without "_" gives stem "" and therefore a
        # dot-file named ".<name>"; confirm this is intended.
        stem, _, ext = m["name"].rpartition("_")
        filename = ".".join((stem, ext))

        path = Path(*path_parts, filename)
        fullpath = (dirs.custom_scripts / path).resolve()
        mkdirp(fullpath.parent)

        # chown each intermediate directory; the relative path's last parent
        # is ".", so dirs.custom_scripts itself is included
        for par in path.parents:
            chown_slurm(dirs.custom_scripts / par)
        need_update = True
        if check_hash and fullpath.exists():
            need_update = _hash_file(fullpath) != blob.md5_hash
        if need_update:
            log.info(f"installing custom script: {path} from {blob.name}")
            with fullpath.open("wb") as f:
                blob.download_to_file(f)
            chown_slurm(fullpath, mode=0o755)


def reservation_resource_policies(reservation):
    """
    Return the resource policy names referenced by a reservation dict.

    Each policy URL in ``reservation["resourcePolicies"]`` (a mapping) is
    reduced to its final component, e.g.
    projects/111111/regions/us-central1/resourcePolicies/zebra -> zebra
    """
    policy_urls = reservation.get("resourcePolicies", {}).values()
    return [url.rsplit("/", 1)[-1] for url in policy_urls]


def compute_service(credentials=None, user_agent=USER_AGENT, version="beta"):
    """Make thread-safe compute service handle
    creates a new Http for each request

    Args:
        credentials: credentials used to authorize requests. When None (the
            default), ``get_credentials()`` is used.
        user_agent: User-Agent set on every request; skipped if None.
        version: Compute API version. It is replaced by the configured compute
            endpoint version when one is set, and in that case the discovery
            URL is also switched to the instance's universe domain.
    Returns:
        a googleapiclient discovery resource for the ``compute`` API.
    """

    # honor caller-supplied credentials; only fall back to the defaults
    if credentials is None:
        credentials = get_credentials()

    def build_request(http, *args, **kwargs):
        # ignore the shared `http` and build a fresh one per request (see docstring)
        new_http = httplib2.Http()
        if user_agent is not None:
            new_http = set_user_agent(new_http, user_agent)
        if credentials is not None:
            new_http = google_auth_httplib2.AuthorizedHttp(credentials, http=new_http)
        return googleapiclient.http.HttpRequest(new_http, *args, **kwargs)

    ver = endpoint_version(ApiEndpoint.COMPUTE)
    disc_url = googleapiclient.discovery.DISCOVERY_URI
    if ver:
        version = ver
        disc_url = disc_url.replace(DEFAULT_UNIVERSE_DOMAIN, universe_domain())

    log.debug(f"Using version={version} of Google Compute Engine API")
    return googleapiclient.discovery.build(
        "compute",
        version,
        requestBuilder=build_request,
        credentials=credentials,
        discoveryServiceUrl=disc_url,
    )


def load_config_data(config):
    """load dict-like data into a config object

    Wraps ``config`` in an ``NSDict`` and fills in defaults for settings that
    are missing or falsy. These cover the log/bin dirs, the controller host
    and port range, a munge NFS mount served by the controller, the debug
    logging switch, and one boolean per name in ``logging_flags``.

    Returns:
        the populated ``NSDict``.
    """
    # NSDict (addict) yields an empty, falsy NSDict for a missing key, so each
    # `not cfg.x` test covers both "absent" and "set to a falsy value"
    cfg = NSDict(config)
    if not cfg.slurm_log_dir:
        cfg.slurm_log_dir = dirs.log
    if not cfg.slurm_bin_dir:
        cfg.slurm_bin_dir = slurmdirs.prefix / "bin"
    if not cfg.slurm_control_host:
        cfg.slurm_control_host = f"{cfg.slurm_cluster_name}-controller"
    if not cfg.slurm_control_host_port:
        cfg.slurm_control_host_port = "6820-6830"
    if not cfg.munge_mount:
        # NOTE: should only happen with cloud controller
        cfg.munge_mount = NSDict(
            {
                "server_ip": cfg.slurm_control_addr or cfg.slurm_control_host,
                "remote_mount": "/etc/munge",
                "fs_type": "nfs",
                "mount_options": "defaults,hard,intr,_netdev",
            }
        )

    # a missing key reads as an empty NSDict: normalize only that case to
    # False, leaving an explicit False/None untouched
    if not cfg.enable_debug_logging and isinstance(cfg.enable_debug_logging, NSDict):
        cfg.enable_debug_logging = False
    # keep exactly the known flags, each defaulting to False
    cfg.extra_logging_flags = NSDict(
        {flag: cfg.extra_logging_flags.get(flag, False) for flag in logging_flags}
    )
    return cfg


def new_config(config):
    """initialize a new config object
    necessary defaults are handled here

    Builds the config with ``load_config_data``. It then scans the cluster,
    login, and per-nodeset (regular, dynamic and TPU) network storage
    entries. Any entry whose ``server_ip`` is None or the ``"$controller"``
    placeholder is pointed at ``cfg.slurm_control_host``. Entries with
    ``fs_type == "gcsfuse"`` are left untouched.
    """
    cfg = load_config_data(config)

    network_storage_iter = filter(
        None,
        (
            *cfg.network_storage,
            *cfg.login_network_storage,
            *chain.from_iterable(ns.network_storage for ns in cfg.nodeset.values()),
            *chain.from_iterable(ns.network_storage for ns in cfg.nodeset_dyn.values()),
            *chain.from_iterable(ns.network_storage for ns in cfg.nodeset_tpu.values()),
        ),
    )
    for netstore in network_storage_iter:
        # compare the mount type, not the entry: an entry is a mapping, so
        # `netstore != "gcsfuse"` was always true and gcsfuse mounts got
        # rewritten to the controller as well
        if netstore.fs_type != "gcsfuse" and (
            netstore.server_ip is None or netstore.server_ip == "$controller"
        ):
            netstore.server_ip = cfg.slurm_control_host
    return cfg


def fetch_config_yaml():
    """Download config.yaml from the slurm bucket and build a config from it."""
    config_blob = blob_get("config.yaml")
    return new_config(yaml.safe_load(config_blob.download_as_text()))


def fetch_config_yaml_md5():
    """Fetch config.yaml blob md5 from bucket

    Returns a ``hashlib`` md5 object over the string form of the blob's
    ``md5_hash`` metadata (not over the file contents).
    """
    config_blob = blob_get("config.yaml")
    config_blob.reload()  # Populate blob with metadata
    return hashlib.md5(str(config_blob.md5_hash).encode(encoding="utf-8"))


def load_config_file(path):
    """Load a config from a YAML file.

    Returns an empty NSDict (after logging a warning) when the file does not
    exist. Otherwise returns the parsed YAML passed through load_config_data.
    """
    config_path = Path(path)
    try:
        raw_text = config_path.read_text()
    except FileNotFoundError:
        log.warning(f"config file not found: {path}")
        return NSDict()
    return load_config_data(yaml.safe_load(raw_text))


def save_config(cfg, path):
    """Serialize ``cfg`` as YAML (with the module's Dumper) and write it to ``path``."""
    target = Path(path)
    serialized = yaml.dump(cfg, Dumper=Dumper)
    target.write_text(serialized)


def filter_logging_flags(record):
    """Logging filter: records without a ``flag`` attribute always pass.

    Flagged records pass only when that flag is enabled in
    ``cfg.extra_logging_flags``.
    """
    flag = getattr(record, "flag", None)
    return True if flag is None else cfg.extra_logging_flags.get(flag, False)


def owned_file_handler(filename):
    """Return a lazily-opened WatchedFileHandler on ``filename``, first making
    sure the file exists and is owned by slurm. Returns None for no filename."""
    if filename is not None:
        chown_slurm(filename)
        return logging.handlers.WatchedFileHandler(filename, delay=True)
    return None


def config_root_logger(caller_logger, level="DEBUG", stdout=True, logfile=None):
    """configure the root logger, disabling all existing loggers

    Args:
        caller_logger: name of the calling script's logger. It is re-enabled
            after configuration, along with this module's logger and the
            resume/suspend/slurmsync/setup loggers.
        level: level for the root logger.
        stdout: attach a stdout handler using the "standard" format.
        logfile: if truthy, also attach a file handler created by
            ``owned_file_handler`` using the timestamped "stamp" format.

    Both handlers accept DEBUG and above and apply ``filter_logging_flags``.
    """
    # keep the handler names whose enabling argument is truthy
    handlers = list(compress(("stdout_handler", "file_handler"), (stdout, logfile)))

    config = {
        "version": 1,
        "disable_existing_loggers": True,
        "formatters": {
            "standard": {
                "()": LogFormatter,
                "fmt": "%(levelname)s: %(message)s",
            },
            "stamp": {
                "()": LogFormatter,
                "fmt": "%(asctime)s %(levelname)s: %(message)s",
            },
        },
        "filters": {
            # the factory returns the filter_logging_flags function itself,
            # which is used as the filter callable
            "logging_flags": {"()": lambda: filter_logging_flags},
        },
        "handlers": {
            "stdout_handler": {
                "level": logging.DEBUG,
                "formatter": "standard",
                "class": "logging.StreamHandler",
                "stream": sys.stdout,
                "filters": ["logging_flags"],
            },
            "file_handler": {
                "()": owned_file_handler,
                "level": logging.DEBUG,
                "formatter": "stamp",
                "filters": ["logging_flags"],
                "filename": logfile,
            },
        },
        "root": {
            "handlers": handlers,
            "level": level,
        },
    }
    # without a logfile, drop the file handler definition so dictConfig
    # never instantiates it
    if not logfile:
        del config["handlers"]["file_handler"]
    logging.config.dictConfig(config)
    # disable_existing_loggers=True disabled loggers created before this call;
    # turn the ones these scripts log through back on
    loggers = (
        __name__,
        "resume",
        "suspend",
        "slurmsync",
        "setup",
        caller_logger,
    )
    for logger in map(logging.getLogger, loggers):
        logger.disabled = False


def log_api_request(request):
    """log.trace info about a compute API request

    Active only when the ``trace_api`` flag is enabled. The request, with its
    nested JSON body decoded, is dumped as YAML and labelled with the name of
    the function that called this one.
    """
    if not log_trace_api.enabled:
        return
    rep = json.loads(request.to_json())
    body = rep.get("body", None)
    if body is not None:
        rep["body"] = json.loads(body)
    pretty_req = yaml.safe_dump(rep).rstrip()
    caller = inspect.stack()[1].function
    log_trace_api.debug(f"{caller}:\n{pretty_req}")


def handle_exception(exc_type, exc_value, exc_trace):
    """sys.excepthook replacement: log any non-KeyboardInterrupt exception as
    fatal, then hand off to the default hook in every case."""
    details = (exc_type, exc_value, exc_trace)
    if not issubclass(exc_type, KeyboardInterrupt):
        log.exception("Fatal exception", exc_info=details)
    sys.__excepthook__(*details)


def run(
    args,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    shell=False,
    timeout=None,
    check=True,
    universal_newlines=True,
    **kwargs,
):
    """Wrapper for subprocess.run() with convenient defaults

    A list ``args`` has its None entries dropped and is joined with spaces.
    Unless ``shell`` is set, the resulting string is re-tokenized with
    ``shlex.split``. NOTE(review): because of this join/split, a list element
    containing spaces ends up as several arguments.
    """
    if isinstance(args, list):
        args = " ".join(arg for arg in args if arg is not None)
    if isinstance(args, str) and not shell:
        args = shlex.split(args)
    log_subproc.debug(f"run: {args}")
    return subprocess.run(
        args,
        stdout=stdout,
        stderr=stderr,
        shell=shell,
        timeout=timeout,
        check=check,
        universal_newlines=universal_newlines,
        **kwargs,
    )


def spawn(cmd, quiet=False, shell=False, **kwargs):
    """Start ``cmd`` without waiting for it and return the Popen handle.

    With shell=False the command string is tokenized by ``shlex.split``.
    With shell=True it is passed to the shell as-is. Unless ``quiet``, the
    command is logged through the subproc flag logger first.
    """
    if not quiet:
        log_subproc.debug(f"spawn: {cmd}")
    if shell:
        argv = cmd
    else:
        argv = shlex.split(cmd)
    return subprocess.Popen(argv, shell=shell, **kwargs)


def chown_slurm(path: Path, mode=None) -> None:
    """Make sure ``path`` exists and give it to slurm:slurm.

    A missing file is created (parents included), using ``mode`` when given.
    An existing file is only chmod-ed when ``mode`` is given. Ownership
    failures are logged, never raised.
    """
    if not path.exists():
        mkdirp(path.parent)
        touch_kwargs = {"mode": mode} if mode else {}
        path.touch(**touch_kwargs)
    elif mode:
        path.chmod(mode)
    try:
        shutil.chown(path, user="slurm", group="slurm")
    except LookupError:
        log.warning(f"User 'slurm' does not exist. Cannot 'chown slurm:slurm {path}'.")
    except PermissionError:
        log.warning(f"Not authorized to 'chown slurm:slurm {path}'.")
    except Exception as err:
        log.error(err)


@contextmanager
def cd(path):
    """Temporarily switch the working directory to ``path``; always switch back."""
    origin = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(origin)


def cached_property(f):
    """Read-only property whose getter result is memoized per instance via
    lru_cache (default maxsize 128; the cache keeps instances alive)."""
    memoized_getter = lru_cache()(f)
    return property(memoized_getter)


def retry(max_retries: int, init_wait_time: float, warn_msg: str, exc_type: Exception):
    """Retries functions that raises the exception exc_type.
    Retry time is increased by a factor of two for every iteration.

    Args:
        max_retries (int): Maximum number of attempts (>= 1)
        init_wait_time (float): Initial wait time in secs (> 0)
        warn_msg (str): Message to print during retries
        exc_type: Exception class (or tuple of classes) to retry on

    Raises:
        ValueError: immediately, for invalid ``max_retries``/``init_wait_time``.
    The wrapped function re-raises the last ``exc_type`` error once all
    attempts have failed, without a final pointless wait.
    """

    if max_retries <= 0:
        raise ValueError("Incorrect value for max_retries, must be >= 1")
    if init_wait_time <= 0.0:
        raise ValueError("Invalid value for init_wait_time, must be > 0.0")

    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            secs = init_wait_time
            for attempt in range(1, max_retries + 1):
                try:
                    return f(*args, **kwargs)
                except exc_type:
                    if attempt == max_retries:
                        # out of attempts: don't warn/sleep before giving up
                        raise
                    # log.warn is deprecated; use log.warning
                    log.warning(f"{warn_msg}, retrying in {secs}")
                    sleep(secs)
                    secs *= 2

        return wrapper

    return decorator


def separate(pred, coll):
    """Split ``coll`` into two lists by ``pred``.

    Returns ``(items where pred is False, items where pred is True)``; the
    result of ``pred`` is used directly as the tuple index.
    """
    buckets = ([], [])
    for item in coll:
        buckets[pred(item)].append(item)
    return buckets


def chunked(iterable, n=API_REQ_LIMIT):
    """Yield successive lists of at most ``n`` items drawn from ``iterable``."""
    source = iter(iterable)
    chunk = list(islice(source, n))
    while chunk:
        yield chunk
        chunk = list(islice(source, n))


def groupby_unsorted(seq, key):
    """Group an indexable ``seq`` by ``key`` without sorting it first.

    Yields ``(key_value, generator of items)`` pairs in order of each key's
    first appearance. Items keep their original relative order.
    """
    positions = defaultdict(list)
    for idx, item in enumerate(seq):
        positions[key(item)].append(idx)
    for group_key, idx_list in positions.items():
        yield group_key, (seq[idx] for idx in idx_list)


@lru_cache(maxsize=32)
def find_ratio(a, n, s, r0=None):
    """given the start (a), count (n), and sum (s), find the ratio required

    Solves ``a * (1 - r**n) / (1 - r) == s`` for the common ratio ``r`` of an
    ``n``-term geometric series starting at ``a``, using Newton's method.
    Results are memoized (lru_cache, 32 entries).

    Args:
        a: first term.
        n: number of terms.
        s: target sum of the ``n`` terms.
        r0: initial guess. When None, 1.1 is used if ``s`` exceeds ``a * n``
            (so r > 1), else 0.9.
    Returns:
        the ratio. For n == 2 it is computed in closed form (a + a*r == s).
        It is 1 for n == 1 or when ``s == a * n``. If a zero division occurs
        or there is no convergence within 64 iterations, an error is logged
        and the initial guess ``r0`` is returned instead.
    """
    if n == 2:
        return s / a - 1
    an = a * n
    if n == 1 or s == an:
        return 1
    if r0 is None:
        # we only need to know which side of 1 to guess, and the iteration will work
        r0 = 1.1 if an < s else 0.9

    # geometric sum formula
    def f(r):
        return a * (1 - r**n) / (1 - r) - s

    # derivative of f
    def df(r):
        rm1 = r - 1
        rn = r**n
        return (a * (rn * (n * rm1 - r) + r)) / (r * rm1**2)

    MIN_DR = 0.0001  # negligible change
    r = r0
    # print(f"r(0)={r0}")
    MAX_TRIES = 64
    for i in range(1, MAX_TRIES + 1):
        try:
            # Newton step; f and df both divide by (r - 1), so r == 1 lands here
            dr = f(r) / df(r)
        except ZeroDivisionError:
            log.error(f"Failed to find ratio due to zero division! Returning r={r0}")
            return r0
        r = r - dr
        # print(f"r({i})={r}")
        # if the change in r is small, we are close enough
        if abs(dr) < MIN_DR:
            break
    else:
        # loop finished without the convergence break
        log.error(f"Could not find ratio after {MAX_TRIES}! Returning r={r0}")
        return r0
    return r


def backoff_delay(start, timeout=None, ratio=None, count: int = 0):
    """generates `count` waits starting at `start`
    sum of waits is `timeout` or each one is `ratio` bigger than the last
    the last wait is always 0

    Exactly one of ``timeout`` / ``ratio`` must be given. With ``timeout``,
    the ratio is solved by ``find_ratio`` so that the ``count - 1`` non-zero
    waits add up to ``timeout``. ``count == 0`` is only allowed together with
    ``timeout``; the count is then derived from ``start`` and ``timeout`` by
    the heuristic below.

    Yields: ``start``, then ``count - 2`` geometrically growing waits, then 0.
    NOTE(review): count == 1 (permitted when timeout is set) still yields two
    values (start, 0).
    """
    # timeout or ratio must be set but not both
    # NOTE(review): argument checks are asserts, so they vanish under `python -O`
    assert (timeout is None) ^ (ratio is None)
    assert ratio is None or ratio > 0
    assert timeout is None or timeout >= start
    assert (count > 1 or timeout is not None) and isinstance(count, int)
    assert start > 0

    if count == 0:
        # Equation for auto-count is tuned to have a max of
        # ~int(timeout) counts with a start wait of <0.01.
        # Increasing start wait decreases count eg.
        # backoff_delay(10, timeout=60) -> count = 5
        count = int(
            (timeout / ((start + 0.05) ** (1 / 2)) + 2) // math.log(timeout + 2)
        )

    yield start
    # if ratio is set:
    # timeout = start * (1 - ratio**(count - 1)) / (1 - ratio)
    if ratio is None:
        ratio = find_ratio(start, count - 1, timeout)

    wait = start
    # we have start and 0, so we only need to generate count - 2
    for _ in range(count - 2):
        wait *= ratio
        yield wait
    yield 0
    return


ROOT_URL = "http://metadata.google.internal/computeMetadata/v1"


def get_metadata(path, root=ROOT_URL):
    """Return the text of metadata server entry ``<root>/<path>``.

    Any requests failure (including a non-2xx status) is logged at debug level
    and re-raised as a plain Exception naming the URL.
    """
    url = f"{root}/{path}"
    try:
        response = get_url(url, headers={"Metadata-Flavor": "Google"})
        response.raise_for_status()
    except RequestException:
        log.debug(f"metadata not found ({url})")
        raise Exception(f"failed to get_metadata from {url}")
    return response.text


@lru_cache(maxsize=None)
def instance_metadata(path):
    """Return (and cache) instance metadata at ``<ROOT_URL>/instance/<path>``."""
    instance_root = f"{ROOT_URL}/instance"
    return get_metadata(path, root=instance_root)


@lru_cache(maxsize=None)
def project_metadata(key):
    """Return (and cache) project metadata at ``<ROOT_URL>/project/attributes/<key>``."""
    attributes_root = f"{ROOT_URL}/project/attributes"
    return get_metadata(key, root=attributes_root)


def bucket_blob_download(bucket_name, blob_name):
    """Download ``blob_name`` from ``bucket_name`` and return it as text.

    The blob is written to a temporary file, which is then read back in text
    mode.
    """
    from google.cloud import storage

    # pass the enum member: endpoint_version() needs `.value`, so the plain
    # string "storage" raised AttributeError
    co = create_client_options(ApiEndpoint.STORAGE)
    storage_client = storage.Client(client_options=co)
    blob = storage_client.bucket(bucket_name).blob(blob_name)
    with tempfile.NamedTemporaryFile(mode="w+t") as tmp:
        blob.download_to_filename(tmp.name)
        with open(tmp.name, "r") as f:
            return f.read()


def natural_sort(text):
    """Sort key that compares digit runs numerically, e.g. "n2" < "n10"."""
    return [
        int(piece) if piece.isdigit() else piece
        for piece in re.split(r"(\d+)", text)
    ]


# TODO: replace with to_hostlist_fast
def to_hostlist(nodenames) -> str:
    """make hostlist from list of node names

    Writes the naturally-sorted names, one per line, to a temporary file and
    compresses them with ``scontrol show hostlist <file>``.

    Args:
        nodenames: iterable of node names. It is materialized into a list
            first, so generators work (len() is needed for logging).
    Returns:
        the hostlist expression printed by scontrol, trailing whitespace stripped.
    """
    nodenames = list(nodenames)
    # use tmp file because list could be large
    with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as tmp_file:
        tmp_file.write("\n".join(sorted(nodenames, key=natural_sort)))
    try:
        hostlist = run(f"{lkp.scontrol} show hostlist {tmp_file.name}").stdout.rstrip()
    finally:
        # remove the temp file even if scontrol fails
        os.remove(tmp_file.name)
    # plain f-string: the stray .format() would choke on braces in the text
    log_hostlists.debug(f"hostlist({len(nodenames)}): {hostlist}")
    return hostlist


def to_hostlist_fast(names: Iterable[str]) -> str:
    """
    Fast implementation of to_hostlist that doesn't invoke `scontrol`
    IMPORTANT:
    * Acts as `scontrol show hostlistsorted`, i.e. original order is not preserved
    * Achieves worse compression than `to_hostlist` for some cases

    Each name is split into a non-numeric prefix and a trailing digit run.
    Within a prefix, suffixes of equal width are collapsed into ``lo-hi``
    ranges (widths kept separate so zero padding survives). Falsy names are
    skipped.
    """
    split_name = re.compile(r"^(.*?)(\d*)$")
    suffixes_by_prefix = defaultdict(list)
    for hostname in names:
        if not hostname:
            continue
        prefix, digits = split_name.match(hostname).groups()
        suffixes_by_prefix[prefix].append(digits)

    def _ranges(digit_strs: List[str]) -> List[str]:
        """Collapse same-width numeric strings into "n" / "lo-hi" tokens."""
        tokens = []
        lo_num = hi_num = None
        lo_str = hi_str = None
        for digits in sorted(digit_strs, key=int):
            value = int(digits)
            if lo_num is not None and value == hi_num + 1:
                # extends the current run
                hi_num, hi_str = value, digits
                continue
            if lo_num is not None:
                tokens.append(lo_str if lo_num == hi_num else f"{lo_str}-{hi_str}")
            lo_num = hi_num = value
            lo_str = hi_str = digits
        if lo_num is not None:
            tokens.append(lo_str if lo_num == hi_num else f"{lo_str}-{hi_str}")
        return tokens

    pieces = []
    for prefix in sorted(suffixes_by_prefix):
        by_width = defaultdict(list)
        for digits in suffixes_by_prefix[prefix]:
            by_width[len(digits)].append(digits)
        ranges = []
        for width in sorted(by_width):
            if width:
                ranges.extend(_ranges(by_width[width]))
            else:
                # bare prefix (a name with no numeric suffix)
                pieces.append(prefix)
        if not ranges:
            continue
        if len(ranges) == 1 and "-" not in ranges[0]:
            pieces.append(prefix + ranges[0])
        else:
            pieces.append(f"{prefix}[{','.join(ranges)}]")
    return ",".join(pieces)


def part_is_tpu(part):
    """Return True if partition ``part`` lists at least one TPU nodeset."""
    tpu_nodesets = lkp.cfg.partitions[part].partition_nodeset_tpu
    return len(tpu_nodesets) > 0


def get_vmcount_of_tpu_part(part):
    """Return the VM count shared by the TPU nodesets of partition ``part``.

    Returns 0 when the partition has no TPU nodesets. Returns -1 when its TPU
    nodesets disagree (which should not happen within one partition).
    """
    vmcount = 0
    for nodeset_name in lkp.cfg.partitions[part].partition_nodeset_tpu:
        current = TPU(lkp.cfg.nodeset_tpu[nodeset_name]).vmcount
        if vmcount == 0:
            vmcount = current
        elif vmcount != current:
            return -1
    return vmcount


def to_hostnames(nodelist: str) -> List[str]:
    """Expand a hostlist expression (or an iterable of expressions, joined
    with commas) into individual hostnames via ``scontrol show hostnames``."""
    if not nodelist:
        return []  # avoid degenerate invocation of scontrol
    hostlist = nodelist if isinstance(nodelist, str) else ",".join(nodelist)
    result = run(f"{lkp.scontrol} show hostnames {hostlist}")
    hostnames = result.stdout.splitlines()
    log_hostlists.debug(f"hostnames({len(hostnames)}) from {hostlist}")
    return hostnames


def retry_exception(exc):
    """Return True if ``exc``'s message marks a rate-limit or quota error,
    which should always be retried."""
    message = str(exc)
    return any(
        marker in message for marker in ("Rate Limit Exceeded", "Quota Exceeded")
    )


def ensure_execute(request):
    """Execute a googleapiclient request, retrying transient failures.

    Rate-limit/quota HttpErrors (see ``retry_exception``) and socket timeouts
    are retried, sleeping for the delays produced by ``backoff_delay``. Any
    other HttpError is re-raised unchanged, and any other exception is logged
    and re-raised.

    Returns:
        The response of ``request.execute()``.

    Raises:
        The last transient error if every retry attempt failed. Before this
        fix, exhausted retries returned None and a socket timeout was never
        retried.
    """
    last_exc = None
    for retry, wait in enumerate(backoff_delay(0.5, timeout=10 * 60, count=20)):
        try:
            return request.execute()
        except googleapiclient.errors.HttpError as e:
            if not retry_exception(e):
                raise
            log.error(f"retry:{retry} '{e}'")
            last_exc = e
        except socket.timeout as e:
            # socket timed out, back off and try again
            log.debug(e)
            last_exc = e
        except Exception as e:
            log.error(e, exc_info=True)
            raise
        sleep(wait)
    if last_exc is not None:
        raise last_exc


def batch_execute(requests, retry_cb=None, log_err=log.error):
    """Execute a list or dict<req_id, request> of compute requests as batches.

    Each pass of the loop sends up to API_REQ_LIMIT of the still-pending
    requests, split into batch HTTP requests of at most BATCH_LIMIT. The
    batches are executed concurrently on a thread pool via
    ``ensure_execute``. Per-request outcomes are handled in
    ``batch_callback``:

    * success: moved to ``done``, unless ``retry_cb(resp)`` returns true, in
      which case the request stays pending and is sent again next pass;
    * rate-limit/quota error (``retry_exception``): stays pending, and the
      next pass first sleeps until the oldest recorded pass timestamp;
    * any other error: ``(request, exc)`` is moved to ``failed``.

    The loop ends when no request is pending. When a dict is passed, entries
    are popped from that same dict as they complete. A list is first copied
    into a new dict keyed by the stringified index.

    Args:
        requests: list of requests, or dict of request id -> request.
        retry_cb: optional predicate on a response; true means "send again".
        log_err: callable used to log per-request exceptions.

    Returns:
        (done, failed): dict of id -> response, and dict of
        id -> (request, exception).

    Raises:
        The exception of any batch whose execution failed as a whole.
    """

    compute = globals()["compute"]
    BATCH_LIMIT = 1000  # max requests per batch HTTP request
    if not isinstance(requests, dict):
        requests = {str(k): v for k, v in enumerate(requests)}  # rid generated here
    done = {}
    failed = {}
    # One entry per pass: time() + 100, i.e. the end of that pass's 100 s window.
    timestamps = []
    rate_limited = False

    # Invoked once per request while a batch executes (on a worker thread).
    def batch_callback(rid, resp, exc):
        nonlocal rate_limited
        if exc is not None:
            log_err(f"compute request exception {rid}: {exc}")
            if retry_exception(exc):
                # leave the request pending; the next pass will back off
                rate_limited = True
            else:
                req = requests.pop(rid)
                failed[rid] = (req, exc)
        else:
            # if retry_cb is set, don't move to done until it returns false
            if retry_cb is None or not retry_cb(resp):
                requests.pop(rid)
                done[rid] = resp

    def batch_request(reqs):
        batch = compute.new_batch_http_request(callback=batch_callback)
        for rid, req in reqs:
            batch.add(req, request_id=rid)
        return batch

    while requests:
        if timestamps:
            # drop passes whose window has already ended
            timestamps = [stamp for stamp in timestamps if stamp > time()]
        if rate_limited and timestamps:
            # wait until the oldest window still open has ended
            stamp = next(iter(timestamps))
            sleep(max(stamp - time(), 0))
            rate_limited = False
        # up to API_REQ_LIMIT (2000) requests
        # in chunks of up to BATCH_LIMIT (1000)
        batches = [
            batch_request(chunk)
            for chunk in chunked(islice(requests.items(), API_REQ_LIMIT), BATCH_LIMIT)
        ]
        timestamps.append(time() + 100)
        with ThreadPoolExecutor() as exe:
            futures = []
            for batch in batches:
                future = exe.submit(ensure_execute, batch)
                futures.append(future)
            for future in futures:
                # propagate a failure of the batch as a whole
                result = future.exception()
                if result is not None:
                    raise result

    return done, failed


def wait_request(operation, project=None, compute=None):
    """Build, without executing, the ``wait`` request for *operation*.

    The operations collection is chosen by the keys present on the
    operation: "zone" selects zoneOperations, "region" selects
    regionOperations, and anything else uses globalOperations.

    Args:
        operation: operation resource dict. Uses "name" and, if present,
            "zone" or "region".
        project: project id; defaults to ``lkp.project``.
        compute: compute API client; defaults to the module-level ``compute``.

    Returns:
        The unexecuted wait request.
    """
    compute = compute or globals()["compute"]
    project = lkp.project if project is None else project
    if "zone" in operation:
        return compute.zoneOperations().wait(
            project=project,
            zone=trim_self_link(operation["zone"]),
            operation=operation["name"],
        )
    if "region" in operation:
        return compute.regionOperations().wait(
            project=project,
            region=trim_self_link(operation["region"]),
            operation=operation["name"],
        )
    return compute.globalOperations().wait(
        project=project, operation=operation["name"]
    )


def wait_for_operation(operation, project=None, compute=None):
    """Block until *operation* reaches status DONE and return its final state.

    The compute ``wait`` request is re-executed until the status is DONE. The
    project defaults to the one in the operation's selfLink. Completion is
    logged at debug level, noting whether the result has an "error" entry.
    Errors are not raised here.
    """
    compute = compute or globals()["compute"]
    if project is None:
        project = parse_self_link(operation["selfLink"]).project
    wait_req = wait_request(operation, project=project, compute=compute)

    result = ensure_execute(wait_req)
    while result["status"] != "DONE":
        result = ensure_execute(wait_req)
    log_errors = " with errors" if "error" in result else ""
    log.debug(
        f"operation complete{log_errors}: type={result['operationType']}, name={result['name']}"
    )
    return result


def wait_for_operations(operations, project=None, compute=None):
    """Wait for each operation in turn; return their final states in order."""
    compute = compute or globals()["compute"]
    results = []
    for op in operations:
        results.append(wait_for_operation(op, project=project, compute=compute))
    return results


def get_filtered_operations(
    op_filter,
    zone=None,
    region=None,
    only_global=False,
    project=None,
    compute=None,
):
    """Return all compute operations matching *op_filter*, following all pages.

    The first applicable option picks the scope:
    ``only_global`` uses globalOperations.list, ``zone`` uses
    zoneOperations.list, ``region`` uses regionOperations.list, and otherwise
    an aggregatedList across all locations is used.

    Args:
        op_filter: compute API filter expression.
        zone: zone name for a zonal listing.
        region: region name for a regional listing.
        only_global: list global operations only.
        project: project id; defaults to ``lkp.project``.
        compute: compute API client; defaults to the module-level ``compute``.

    Returns:
        List of operation resources (dicts). The list is empty when nothing
        matches.
    """

    if not compute:
        compute = globals()["compute"]
    if project is None:
        project = lkp.project
    operations = []

    def get_aggregated_operations(items):
        # items is a dict of location key to value: dict(operations=<list of operations>) or an empty dict
        operations.extend(
            chain.from_iterable(
                ops["operations"] for ops in items.values() if "operations" in ops
            )
        )

    def get_list_operations(items):
        operations.extend(items)

    handle_items = get_list_operations
    if only_global:
        act = compute.globalOperations()
        op = act.list(project=project, filter=op_filter)
        nxt = act.list_next
    elif zone is not None:
        act = compute.zoneOperations()
        op = act.list(project=project, zone=zone, filter=op_filter)
        nxt = act.list_next
    elif region is not None:
        act = compute.regionOperations()
        op = act.list(project=project, region=region, filter=op_filter)
        nxt = act.list_next
    else:
        act = compute.globalOperations()
        op = act.aggregatedList(
            project=project, filter=op_filter, fields="items.*.operations,nextPageToken"
        )
        nxt = act.aggregatedList_next
        handle_items = get_aggregated_operations
    while op is not None:
        result = ensure_execute(op)
        # The API omits "items" entirely when nothing matches the filter.
        items = result.get("items")
        if items:
            handle_items(items)
        op = nxt(op, result)
    return operations


def get_insert_operations(group_ids, flt=None, project=None, compute=None):
    """Return all insert operations belonging to the given operationGroupIds.

    Args:
        group_ids: iterable of operationGroupId strings, or a single
            comma-separated string.
        flt: optional extra filter expression, ANDed with the others.
        project: project id; defaults to ``lkp.project``.
        compute: compute API client; defaults to the module-level ``compute``.

    Returns:
        List of matching operation resources. Returns [] when no non-empty
        group id is given. Without that guard the group clause would be
        dropped and every insert operation in the project would match.
    """
    if not compute:
        compute = globals()["compute"]
    if project is None:
        project = lkp.project
    if isinstance(group_ids, str):
        group_ids = group_ids.split(",")
    group_ids = [gid for gid in group_ids if gid]
    if not group_ids:
        return []
    filters = [
        "operationType=insert",
        flt,
        " OR ".join(f"(operationGroupId={gid})" for gid in group_ids),
    ]
    # project/compute were previously resolved but never forwarded.
    return get_filtered_operations(
        " AND ".join(f"({f})" for f in filters if f),
        project=project,
        compute=compute,
    )


def machine_type_sockets(template):
    """Return the number of CPU sockets for the template's machine type.

    The family is the part of ``template.machineType`` before the first "-".
    h3 and a3 have 2 sockets, c2d has 2 sockets above 56 guest CPUs (else 1),
    and every other family is assumed to have 1.

    Raises:
        Exception: if the machine type does not start with a family prefix.
    """
    match = re.match(r"^(?P<family>[^-]+)", template.machineType)
    if match is None:
        raise Exception(f"template {template} does not match expected regex")
    guest_cpus = int(template.machine_info.guestCpus)
    family = match["family"]
    if family in ("h3", "a3"):
        return 2
    if family == "c2d":
        return 2 if guest_cpus > 56 else 1
    return 1


def isSmt(template):
    """Return True if the template's machine type runs with SMT enabled.

    SMT is reported as off for the families in ``noSmtFamily`` and for
    single-vCPU machine types.

    Raises:
        Exception: if ``template.machineType`` does not start with a family
            prefix. This replaces an opaque TypeError from subscripting None
            and matches ``machine_type_sockets``.
    """
    machineType: str = template.machineType
    guestCpus: int = int(template.machine_info.guestCpus)

    matches = re.match(r"^(?P<family>[^-]+)", machineType)
    if matches is None:
        raise Exception(f"template {template} does not match expected regex")
    machineTypeFamily: str = matches["family"]

    # https://cloud.google.com/compute/docs/cpu-platforms
    noSmtFamily = [
        "t2a",
        "t2d",
        "h3",
    ]
    if machineTypeFamily in noSmtFamily:
        return False
    return guestCpus != 1


def getThreadsPerCore(template):
    """Return the effective threads per core for *template*.

    Returns 1 when the machine type has no SMT. Otherwise returns the
    template's ``advancedMachineFeatures.threadsPerCore`` when set (truthy),
    and 2 when it is unset.
    """
    requested: int = template.advancedMachineFeatures.threadsPerCore
    if not isSmt(template):
        return 1
    return requested if requested else 2


@retry(
    max_retries=9,
    init_wait_time=1,
    warn_msg="Temporary failure in name resolution",
    exc_type=socket.gaierror,
)
def host_lookup(host_name: str) -> str:
    """Resolve *host_name* to an IPv4 address string.

    ``socket.gaierror`` failures are retried by the ``retry`` decorator.
    """
    address = socket.gethostbyname(host_name)
    return address


class Dumper(yaml.SafeDumper):
    """SafeDumper that also serializes NSDict (as a mapping) and Path (as a string)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_representer(NSDict, self.represent_nsdict)
        # multi-representer so concrete Path subclasses (PosixPath, ...) match
        self.add_multi_representer(Path, self.represent_path)

    @staticmethod
    def represent_nsdict(dumper, data):
        """Emit an NSDict as a plain YAML map."""
        map_tag = "tag:yaml.org,2002:map"
        return dumper.represent_mapping(map_tag, data.items())

    @staticmethod
    def represent_path(dumper, path):
        """Emit a path as a plain YAML string."""
        str_tag = "tag:yaml.org,2002:str"
        return dumper.represent_scalar(str_tag, str(path))


class TPU:
    """Manage the Cloud TPU-VM nodes of one TPU nodeset through the TPU v2 API.

    An instance wraps a ``tpu.TpuClient`` scoped to
    ``projects/<project>/locations/<nodeset zone>``. It provides create,
    start, stop, get, list and delete operations for nodes in that location.
    """

    if can_tpu:
        State = tpu.types.cloud_tpu.Node.State
        # Divisor that turns a topology's size (product of its dimensions)
        # into a VM count; see __calc_vm_from_topology.
        TPUS_PER_VM = 4
        # Node state expected after each operation, checked by __check_resp.
        __expected_states = {
            "create": State.READY,
            "start": State.READY,
            "stop": State.STOPPED,
        }

        # Nodeset accelerator version string -> API accelerator type enum.
        __tpu_version_mapping = {
            "V2": tpu.AcceleratorConfig().Type.V2,
            "V3": tpu.AcceleratorConfig().Type.V3,
            "V4": tpu.AcceleratorConfig().Type.V4,
        }

    def __init__(self, nodeset):
        """Create a TPU helper for *nodeset*.

        Builds the attached data-disk list and the accelerator config. The
        config comes from ``nodeset.accelerator_config`` when both topology
        and version are set. Otherwise it is the first accelerator config
        the API reports for ``nodeset.node_type``. ``vmcount`` is derived
        from the resulting topology.

        Raises:
            Exception: if the TPU pip package is not installed.
        """
        if not can_tpu:
            raise Exception("TPU pip package not installed")
        self._nodeset = nodeset
        self._parent = f"projects/{lkp.project}/locations/{nodeset.zone}"
        co = create_client_options(ApiEndpoint.TPU)
        self._client = tpu.TpuClient(client_options=co)
        self.data_disks = []
        for data_disk in nodeset.data_disks:
            ad = tpu.AttachedDisk()
            ad.source_disk = data_disk
            ad.mode = tpu.AttachedDisk.DiskMode.DISK_MODE_UNSPECIFIED
            self.data_disks.append(ad)
        ns_ac = nodeset.accelerator_config
        if ns_ac.topology != "" and ns_ac.version != "":
            ac = tpu.AcceleratorConfig()
            ac.topology = ns_ac.topology
            ac.type_ = self.__tpu_version_mapping[ns_ac.version]
            self.ac = ac
        else:
            req = tpu.GetAcceleratorTypeRequest(
                name=f"{self._parent}/acceleratorTypes/{nodeset.node_type}"
            )
            self.ac = self._client.get_accelerator_type(req).accelerator_configs[0]
        self.vmcount = self.__calc_vm_from_topology(self.ac.topology)

    @property
    def nodeset(self):
        """The nodeset config this helper was built from."""
        return self._nodeset

    @property
    def preserve_tpu(self):
        """``preserve_tpu`` setting of the nodeset."""
        return self._nodeset.preserve_tpu

    @property
    def node_type(self):
        """Accelerator type name of the nodeset."""
        return self._nodeset.node_type

    @property
    def tf_version(self):
        """TensorFlow version of the nodeset (suffix of the runtime version)."""
        return self._nodeset.tf_version

    @property
    def enable_public_ip(self):
        """Whether nodes get external IPs."""
        return self._nodeset.enable_public_ip

    @property
    def preemptible(self):
        """Whether nodes are created preemptible."""
        return self._nodeset.preemptible

    @property
    def reserved(self):
        """Whether nodes are created on reserved capacity."""
        return self._nodeset.reserved

    @property
    def service_account(self):
        """Service account config of the nodeset."""
        return self._nodeset.service_account

    @property
    def zone(self):
        """Zone of the nodeset."""
        return self._nodeset.zone

    def check_node_type(self):
        """Return True if ``node_type`` is a known accelerator type here."""
        if self.node_type is None:
            return False
        try:
            request = tpu.GetAcceleratorTypeRequest(
                name=f"{self._parent}/acceleratorTypes/{self.node_type}"
            )
            return self._client.get_accelerator_type(request=request) is not None
        except Exception:
            return False

    def check_tf_version(self):
        """Return True if the runtime version used by ``create_node`` exists.

        ``create_node`` requests runtime ``tpu-vm-tf-<tf_version>``, so that
        same id is checked here.
        """
        try:
            request = tpu.GetRuntimeVersionRequest(
                name=f"{self._parent}/runtimeVersions/tpu-vm-tf-{self.tf_version}"
            )
            return self._client.get_runtime_version(request=request) is not None
        except Exception:
            return False

    def __calc_vm_from_topology(self, topology):
        """Return the VM count for a topology string such as "2x2x4".

        The count is the product of the dimensions divided by TPUS_PER_VM,
        clamped to at least 1. Every TPU node has at least one VM, and a
        count of 0 would break chunking and the ``vmcount > 1`` checks.
        """
        tot = 1
        for num in topology.split("x"):
            tot *= int(num)
        return max(1, tot // self.TPUS_PER_VM)

    def __check_resp(self, response, op_name):
        """Return True if *response* is a Node in the state expected after *op_name*.

        Returns False for operations without an expected state and for
        responses that are not Node objects.
        """
        des_state = self.__expected_states.get(op_name)
        if des_state is None:
            return False
        if response.__class__.__name__ != "Node":
            return False
        return response.state == des_state

    def list_nodes(self):
        """Return the API's node listing for this location, or None if not found."""
        try:
            request = tpu.ListNodesRequest(parent=self._parent)
            res = self._client.list_nodes(request=request)
        except gExceptions.NotFound:
            res = None
        return res

    def list_node_names(self):
        """Return the short names of the nodes in this location ([] if not found)."""
        nodes = self.list_nodes()
        if nodes is None:
            return []
        return [node.name.split("/")[-1] for node in nodes]

    def start_node(self, nodename):
        """Start *nodename*; return True if it ends up READY."""
        request = tpu.StartNodeRequest(name=f"{self._parent}/nodes/{nodename}")
        resp = self._client.start_node(request=request).result()
        return self.__check_resp(resp, "start")

    def stop_node(self, nodename):
        """Stop *nodename*; return True if it ends up STOPPED."""
        request = tpu.StopNodeRequest(name=f"{self._parent}/nodes/{nodename}")
        resp = self._client.stop_node(request=request).result()
        return self.__check_resp(resp, "stop")

    def get_node(self, nodename):
        """Return the API node object for *nodename*, or None if not found."""
        try:
            request = tpu.GetNodeRequest(name=f"{self._parent}/nodes/{nodename}")
            res = self._client.get_node(request=request)
        except gExceptions.NotFound:
            res = None
        return res

    def _register_node(self, nodename, ip_addr):
        """Point Slurm node *nodename* at *ip_addr* and its reverse-DNS name."""
        dns_name = socket.getnameinfo((ip_addr, 0), 0)[0]
        run(
            f"{lkp.scontrol} update nodename={nodename} nodeaddr={ip_addr} nodehostname={dns_name}"
        )

    def create_node(self, nodename):
        """Create a TPU node and register its VM(s) with Slurm.

        Args:
            nodename: a single Slurm node name, or, for multi-VM TPUs, a list
                of exactly ``vmcount`` names. The first name is the TPU node id.

        Returns:
            True if the node reached READY and its endpoints were registered.
            False if the arguments do not match ``vmcount`` or the node did
            not become READY.
        """
        if self.vmcount > 1 and not isinstance(nodename, list):
            log.error(
                f"Tried to create a {self.vmcount} node TPU on nodeset {self._nodeset.nodeset_name} but only received one nodename {nodename}"
            )
            return False
        if self.vmcount > 1 and (
            isinstance(nodename, list) and len(nodename) != self.vmcount
        ):
            log.error(
                f"Expected to receive a list of {self.vmcount} nodenames for TPU node creation in nodeset {self._nodeset.nodeset_name}, but received this list {nodename}"
            )
            return False

        node = tpu.Node()
        node.accelerator_config = self.ac
        node.runtime_version = f"tpu-vm-tf-{self.tf_version}"
        # Raises if the startup script is missing; there is no fallback.
        with open(
            Path(cfg.slurm_scripts_dir or dirs.scripts) / "startup.sh", "r"
        ) as script:
            startup_script = script.read()
        if isinstance(nodename, list):
            node_id = nodename[0]
            slurm_names = [
                f"WORKER_{wid}:{node_wid}" for wid, node_wid in enumerate(nodename)
            ]
        else:
            node_id = nodename
            slurm_names = [f"WORKER_0:{nodename}"]
        node.metadata = {
            "slurm_docker_image": self.nodeset.docker_image,
            "startup-script": startup_script,
            "slurm_instance_role": "compute",
            "slurm_cluster_name": lkp.cfg.slurm_cluster_name,
            "slurm_bucket_path": lkp.cfg.bucket_path,
            "slurm_names": ";".join(slurm_names),
            "universe_domain": universe_domain(),
        }
        node.tags = [lkp.cfg.slurm_cluster_name]
        if self.nodeset.service_account:
            node.service_account.email = self.nodeset.service_account.email
            node.service_account.scope = self.nodeset.service_account.scopes
        node.scheduling_config.preemptible = self.preemptible
        node.scheduling_config.reserved = self.reserved
        node.network_config.subnetwork = self.nodeset.subnetwork
        node.network_config.enable_external_ips = self.enable_public_ip
        if self.data_disks:
            node.data_disks = self.data_disks

        request = tpu.CreateNodeRequest(parent=self._parent, node=node, node_id=node_id)
        resp = self._client.create_node(request=request).result()
        if not self.__check_resp(resp, "create"):
            return False
        if isinstance(nodename, list):
            # one network endpoint per VM, in the same order as the names
            for node_id, net_endpoint in zip(nodename, resp.network_endpoints):
                self._register_node(node_id, net_endpoint.ip_address)
        else:
            ip_add = resp.network_endpoints[0].ip_address
            self._register_node(nodename, ip_add)
        return True

    def delete_node(self, nodename):
        """Delete TPU node *nodename*.

        Returns:
            True if the node is gone after deletion or was not found. False
            if the delete returned a falsy result or the node still exists.
        """
        request = tpu.DeleteNodeRequest(name=f"{self._parent}/nodes/{nodename}")
        try:
            resp = self._client.delete_node(request=request).result()
            if resp:
                return self.get_node(nodename=nodename) is None
            return False
        except gExceptions.NotFound:
            # only log an error if vmcount is 1, as for other TPU vm counts this could be a "phantom" node
            if self.vmcount == 1:
                log.error(f"Tpu single node {nodename} not found")
            else:
                # For TPU nodes made of more than one VM, only the first node of the TPU
                # (the master node) exists as a real TPU node, so the other ones are
                # expected to not be found. Check the hostname of the node that was
                # not found: if it ends in 0 it is the master node, which should have
                # been found, so log an error.
                nodehostname = yaml.safe_load(
                    run(f"{lkp.scontrol} --yaml show node {nodename}").stdout.rstrip()
                )["nodes"][0]["hostname"]
                if nodehostname.split("-")[-1] == "0":
                    log.error(f"TPU master node {nodename} not found")
                else:
                    log.info(f"Deleted TPU 'phantom' node {nodename}")
            # If the node is not found it is technically deleted, so return success.
            return True


class Lookup:
    """Wrapper class for cached data access"""

    def __init__(self, cfg=None):
        self._cfg = cfg or NSDict()
        self.template_cache_path = Path(__file__).parent / "template_info.cache"

    @property
    def cfg(self):
        return self._cfg

    @property
    def project(self):
        return self.cfg.project or authentication_project()

    @property
    def control_addr(self):
        return self.cfg.slurm_control_addr

    @property
    def control_host(self):
        return self.cfg.slurm_control_host

    @cached_property
    def control_host_addr(self):
        """IP address of ``slurm_control_host``, resolved via ``host_lookup``."""
        control_host_name = self.cfg.slurm_control_host
        return host_lookup(control_host_name)

    @property
    def control_host_port(self):
        return self.cfg.slurm_control_host_port

    @property
    def endpoint_versions(self):
        return self.cfg.endpoint_versions

    @property
    def scontrol(self):
        return Path(self.cfg.slurm_bin_dir if cfg else "") / "scontrol"

    @cached_property
    def instance_role(self):
        """This VM's ``slurm_instance_role`` metadata attribute."""
        metadata_key = "attributes/slurm_instance_role"
        return instance_metadata(metadata_key)

    @cached_property
    def instance_role_safe(self):
        try:
            role = self.instance_role
        except Exception as e:
            log.error(e)
            role = None
        return role

    @cached_property
    def compute(self):
        """Compute API client, built by ``compute_service``.

        If ``cfg.google_app_cred_path`` is set, it is exported as
        GOOGLE_APPLICATION_CREDENTIALS before the client is built.
        """
        # TODO evaluate when we need to use google_app_cred_path
        cred_path = self.cfg.google_app_cred_path
        if cred_path:
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
        return compute_service()

    @cached_property
    def hostname(self):
        return socket.gethostname()

    @cached_property
    def hostname_fqdn(self):
        """This machine's fully qualified name, as reported by ``socket.getfqdn``."""
        fqdn = socket.getfqdn()
        return fqdn

    @cached_property
    def zone(self):
        """This VM's ``zone`` metadata value."""
        metadata_key = "zone"
        return instance_metadata(metadata_key)

    # Parses node names of the form "<cluster>-<nodeset>-<node>":
    #   cluster: no whitespace or dashes
    #   nodeset: non-whitespace, may itself contain dashes
    #   node:    a word suffix (e.g. "12") or a bracketed range (e.g. "[0-3,7]")
    # "prefix" captures "<cluster>-<nodeset>".
    node_desc_regex = re.compile(
        r"^(?P<prefix>(?P<cluster>[^\s\-]+)-(?P<nodeset>\S+))-(?P<node>(?P<suffix>\w+)|(?P<range>\[[\d,-]+\]))$"
    )

    @lru_cache(maxsize=None)
    def _node_desc(self, node_name):
        """Get parts from node name"""
        if not node_name:
            node_name = self.hostname
        # workaround below is for VMs whose hostname is FQDN
        node_name_short = node_name.split(".")[0]
        m = self.node_desc_regex.match(node_name_short)
        if not m:
            raise Exception(f"node name {node_name} is not valid")
        return m.groupdict()

    def node_prefix(self, node_name=None):
        return self._node_desc(node_name)["prefix"]

    def node_nodeset_name(self, node_name=None):
        return self._node_desc(node_name)["nodeset"]

    def node_nodeset(self, node_name=None):
        nodeset_name = self.node_nodeset_name(node_name)
        ns = self.cfg.nodeset.get(nodeset_name)
        if ns:
            return ns
        return self.cfg.nodeset_tpu.get(nodeset_name)

    def node_is_tpu(self, node_name=None):
        nodeset_name = self.node_nodeset_name(node_name)
        return self.cfg.nodeset_tpu.get(nodeset_name) is not None

    def node_is_dyn(self, node_name=None) -> bool:
        nodeset = self.node_nodeset_name(node_name)
        return self.cfg.nodeset_dyn.get(nodeset) is not None

    def chunk_tpu_nodes(self, tpu_nodes):
        """Split *tpu_nodes* into groups of one TPU's VM count.

        The VM count is taken from the nodeset of the first node only, so all
        nodes are presumed to share that nodeset.
        """
        first_node = tpu_nodes[0]
        tpu_handle = TPU(self.node_nodeset(first_node))
        return chunked(tpu_nodes, n=tpu_handle.vmcount)

    def node_template(self, node_name=None):
        return self.node_nodeset(node_name).instance_template

    def node_template_info(self, node_name=None):
        """Return ``template_info`` for *node_name*'s instance template."""
        template_link = self.node_template(node_name)
        return self.template_info(template_link)

    def node_region(self, node_name=None):
        """Return the region of the subnetwork of *node_name*'s nodeset."""
        subnetwork = self.node_nodeset(node_name).subnetwork
        return parse_self_link(subnetwork).region

    def nodeset_prefix(self, nodeset_name):
        return f"{self.cfg.slurm_cluster_name}-{nodeset_name}"

    def nodelist_range(self, nodeset_name: str, start: int, count: int) -> str:
        assert 0 <= start and 0 < count
        pref = self.nodeset_prefix(nodeset_name)
        if count == 1:
            return f"{pref}-{start}"
        return f"{pref}-[{start}-{start + count - 1}]"

    def static_dynamic_sizes(self, nodeset: object) -> int:
        return (nodeset.node_count_static or 0, nodeset.node_count_dynamic_max or 0)

    def nodelist(self, nodeset) -> str:
        cnt = sum(self.static_dynamic_sizes(nodeset))
        if cnt == 0:
            return ""
        return self.nodelist_range(nodeset.nodeset_name, 0, cnt)

    def nodenames(self, nodeset) -> Tuple[Iterable[str], Iterable[str]]:
        pref = self.nodeset_prefix(nodeset.nodeset_name)
        s_count, d_count = self.static_dynamic_sizes(nodeset)
        return (
            (f"{pref}-{i}" for i in range(s_count)),
            (f"{pref}-{i}" for i in range(s_count, s_count + d_count)),
        )

    def power_managed_nodesets(self) -> Iterable[object]:
        return chain(self.cfg.nodeset.values(), self.cfg.nodeset_tpu.values())

    def is_power_managed_node(self, node_name: str) -> bool:
        try:
            ns = self.node_nodeset(node_name)
            if ns is None:
                return False
            idx = int(self._node_desc(node_name)["suffix"])
            return idx < sum(self.static_dynamic_sizes(ns))
        except Exception:
            return False

    def is_static_node(self, node_name: str) -> bool:
        if not self.is_power_managed_node(node_name):
            return False
        idx = int(self._node_desc(node_name)["suffix"])
        return idx < self.node_nodeset(node_name).node_count_static

    @lru_cache(maxsize=None)
    def slurm_nodes(self):
        """Return {node_name: StateTuple(base, flags)} for CLOUD/DYNAMIC_NORM nodes.

        ``scontrol show nodes`` is reduced with grep/paste to "name,state"
        lines. A state "BASE+FLAG1+FLAG2" is split into the base state and a
        set of flags.
        """
        StateTuple = namedtuple("StateTuple", "base,flags")

        def make_node_tuple(node_line):
            # state flags include: CLOUD, COMPLETING, DRAIN, FAIL, POWERED_DOWN,
            #   POWERING_DOWN
            node, fullstate = node_line.split(",")
            base, *flags = fullstate.split("+")
            return node, StateTuple(base, set(flags))

        cmd = (
            f"{self.scontrol} show nodes | "
            r"grep -oP '^NodeName=\K(\S+)|\s+State=\K(\S+)' | "
            r"paste -sd',\n'"
        )
        node_lines = run(cmd, shell=True).stdout.rstrip().splitlines()
        parsed = (make_node_tuple(line) for line in node_lines)
        return {
            node: state
            for node, state in parsed
            if "CLOUD" in state.flags or "DYNAMIC_NORM" in state.flags
        }

    def slurm_node(self, nodename):
        """Return the ``slurm_nodes`` state entry for *nodename*, or None."""
        nodes = self.slurm_nodes()
        return nodes.get(nodename)

    @lru_cache(maxsize=1)
    def instances(self, project=None, slurm_cluster_name=None):
        """Return this cluster's instances keyed by name.

        Instances are fetched with a compute aggregatedList across all zones,
        filtered on the ``slurm_cluster_name`` label and the name prefix.
        Only instances whose metadata contains ``slurm_instance_role`` are
        kept. Each is reshaped into an NSDict by ``properties``.

        Args:
            project: project id; defaults to ``self.project``.
            slurm_cluster_name: defaults to ``cfg.slurm_cluster_name``.

        Returns:
            dict of instance name -> NSDict of the selected instance fields.
        """
        slurm_cluster_name = slurm_cluster_name or self.cfg.slurm_cluster_name
        project = project or self.project
        instance_information_fields = [
            "advancedMachineFeatures",
            "cpuPlatform",
            "creationTimestamp",
            "disks",
            "fingerprint",
            "guestAccelerators",
            "hostname",
            "id",
            "kind",
            "labelFingerprint",
            "labels",
            "lastStartTimestamp",
            "lastStopTimestamp",
            "lastSuspendedTimestamp",
            "machineType",
            "metadata",
            "name",
            "networkInterfaces",
            "resourceStatus",
            "scheduling",
            "selfLink",
            "serviceAccounts",
            "shieldedInstanceConfig",
            "shieldedInstanceIntegrityPolicy",
            "sourceMachineImage",
            "status",
            "statusMessage",
            "tags",
            "zone",
            # "deletionProtection",
            # "startRestricted",
        ]
        # Use this Lookup (not the module-level ``lkp``) so a Lookup built
        # from a different cfg behaves consistently.
        if self.cfg.enable_slurm_gcp_plugins:
            slurm_gcp_plugins.register_instance_information_fields(
                lkp=self,
                project=project,
                slurm_cluster_name=slurm_cluster_name,
                instance_information_fields=instance_information_fields,
            )
        instance_information_fields = sorted(set(instance_information_fields))
        instance_fields = ",".join(instance_information_fields)
        fields = f"items.zones.instances({instance_fields}),nextPageToken"
        flt = f"labels.slurm_cluster_name={slurm_cluster_name} AND name:{slurm_cluster_name}-*"
        act = self.compute.instances()
        op = act.aggregatedList(project=project, fields=fields, filter=flt)

        def properties(inst):
            """change instance properties to a preferred format"""
            inst["zone"] = trim_self_link(inst["zone"])
            inst["machineType"] = trim_self_link(inst["machineType"])
            # metadata is fetched as a dict of dicts like:
            # {'key': key, 'value': value}, kinda silly
            metadata = {
                i["key"]: i["value"] for i in inst.get("metadata", {}).get("items", [])
            }
            if "slurm_instance_role" not in metadata:
                return None
            inst["role"] = metadata["slurm_instance_role"]
            inst["metadata"] = metadata
            return NSDict(inst)

        instances = {}
        while op is not None:
            result = ensure_execute(op)
            # Zone entries without instances are skipped, as in machine_types.
            instance_iter = (
                (inst["name"], properties(inst))
                for inst in chain.from_iterable(
                    m["instances"]
                    for m in result.get("items", {}).values()
                    if "instances" in m
                )
            )
            instances.update(
                {name: props for name, props in instance_iter if props is not None}
            )
            op = act.aggregatedList_next(op, result)
        return instances

    def instance(self, instance_name, project=None, slurm_cluster_name=None):
        """Return the record for *instance_name* from ``instances``, or None."""
        by_name = self.instances(project=project, slurm_cluster_name=slurm_cluster_name)
        return by_name.get(instance_name)

    @lru_cache()
    def reservation(self, name: str, zone: str) -> object:
        """See https://cloud.google.com/compute/docs/reference/rest/v1/reservations"""
        try:
            _, project, _, short_name = name.split("/")
        except ValueError:
            raise ValueError(
                f"Invalid reservation name: '{name}', expected format is 'projects/PROJECT/reservations/NAME'"
            )

        return (
            self.compute.reservations()
            .get(project=project, zone=zone, reservation=short_name)
            .execute()
        )

    @lru_cache(maxsize=1)
    def machine_types(self, project=None):
        """Return all machine types as {name: {zone: machine_type_dict}}.

        Uses a compute aggregatedList (all pages) restricted to name, zone,
        guestCpus, memoryMb and accelerators.

        Args:
            project: project id; defaults to ``self.project``.
        """
        project = project or self.project
        field_names = "name,zone,guestCpus,memoryMb,accelerators"
        fields = f"items.zones.machineTypes({field_names}),nextPageToken"

        machines = defaultdict(dict)
        act = self.compute.machineTypes()
        op = act.aggregatedList(project=project, fields=fields)
        while op is not None:
            result = ensure_execute(op)
            # "items" may be absent from a page, as handled in ``instances``.
            machine_iter = chain.from_iterable(
                m["machineTypes"]
                for m in result.get("items", {}).values()
                if "machineTypes" in m
            )
            for machine in machine_iter:
                name = machine["name"]
                zone = machine["zone"]
                machines[name][zone] = machine

            op = act.aggregatedList_next(op, result)
        return machines

    def machine_type(self, machine_type, project=None, zone=None):
        """Return machine type info (at least guestCpus and memoryMb) as an NSDict.

        Resolution order:
          * with *zone*: fetch the exact machine type from the API;
          * custom types ("[FAMILY-]custom-CPUS-MEM"): guestCpus and memoryMb
            are taken from the name, with no API call;
          * otherwise: the first zone's entry in the ``machine_types``
            listing.

        Raises:
            Exception: if the machine type is not in the listing.
        """
        custom_patt = re.compile(
            r"((?P<family>\w+)-)?custom-(?P<cpus>\d+)-(?P<mem>\d+)"
        )
        custom_match = custom_patt.match(machine_type)
        if zone:
            project = project or self.project
            machine_info = ensure_execute(
                self.compute.machineTypes().get(
                    project=project, zone=zone, machineType=machine_type
                )
            )
        elif custom_match is not None:
            groups = custom_match.groupdict()
            cpus, mem = (groups[k] for k in ["cpus", "mem"])
            machine_info = {
                "guestCpus": int(cpus),
                "memoryMb": int(mem),
            }
        else:
            machines = self.machine_types(project=project)
            # .get: indexing the cached defaultdict would insert an empty entry
            machine_info = next(iter(machines.get(machine_type, {}).values()), None)
            if machine_info is None:
                raise Exception(f"machine type {machine_type} not found")
        return NSDict(machine_info)

    def template_machine_conf(self, template_link, project=None, zone=None):
        """Derive Slurm node hardware settings from an instance template.

        Side effect: sets ``machine_info`` on the template object returned by
        ``template_info``.

        Args:
            template_link: instance template self link or name.
            project: project id used for the template and machine type
                lookups. Previously this argument was ignored.
            zone: optional zone for an exact machine type lookup.

        Returns:
            NSDict with boards, sockets, sockets_per_board, threads_per_core,
            cpus, cores_per_socket and memory (MB).

        Raises:
            Exception: if the template has no machine type.
        """
        # Pass project only when given, so the default call keeps the same
        # template_info cache key (and cached template object) as before.
        if project is None:
            template = self.template_info(template_link)
        else:
            template = self.template_info(template_link, project=project)
        if not template.machineType:
            temp_name = trim_self_link(template_link)
            raise Exception(f"instance template {temp_name} has no machine type")
        template.machine_info = self.machine_type(
            template.machineType, project=project, zone=zone
        )
        machine = template.machine_info

        machine_conf = NSDict()
        machine_conf.boards = 1  # No information, assume 1
        machine_conf.sockets = machine_type_sockets(template)
        # the value below for SocketsPerBoard must be type int
        machine_conf.sockets_per_board = machine_conf.sockets // machine_conf.boards
        machine_conf.threads_per_core = 1
        # SMT-capable template limited to 1 thread per core: report half the guest CPUs.
        _div = 2 if getThreadsPerCore(template) == 1 else 1
        machine_conf.cpus = (
            int(machine.guestCpus / _div) if isSmt(template) else machine.guestCpus
        )
        machine_conf.cores_per_socket = int(machine_conf.cpus / machine_conf.sockets)
        # Because the actual memory on the host will be different than
        # what is configured (e.g. kernel will take it). From
        # experiments, about 16 MB per GB are used (plus about 400 MB
        # buffer for the first couple of GB's. Using 30 MB to be safe.
        gb = machine.memoryMb // 1024
        machine_conf.memory = machine.memoryMb - (400 + (30 * gb))
        return machine_conf

    @contextmanager
    def template_cache(self, writeback=False):
        """Open the on-disk template-info shelf, retrying on OSError.

        Yields the open shelf and always closes it when the ``with`` block
        exits.

        Args:
            writeback: when True, open with flag "c" (read/write, creating the
                file if needed) and shelve writeback enabled; otherwise open
                read-only with flag "r".

        Raises:
            Exception: if every attempt allowed by ``backoff_delay`` failed.
                The last OSError is chained as the cause.
        """
        flag = "c" if writeback else "r"
        err = None
        for wait in backoff_delay(0.125, timeout=60, count=20):
            try:
                cache = shelve.open(
                    str(self.template_cache_path), flag=flag, writeback=writeback
                )
                break
            except OSError as e:
                err = e
                log.debug(f"Failed to access template info cache: {e}")
                sleep(wait)
        else:
            # reached max_count of waits; chain the last OSError so its
            # traceback is not lost (the raise is outside the except block).
            raise Exception(
                f"Failed to access cache file. latest error: {err}"
            ) from err
        try:
            yield cache
        finally:
            cache.close()

    @lru_cache(maxsize=None)
    def template_info(self, template_link, project=None):
        """Return an instance template's properties as an NSDict.

        The template's "properties" are augmented with ``name``, ``link``,
        ``gpu_type`` and ``gpu_count``. Results are memoized in-process
        (lru_cache) and persisted, keyed by template name, in the shelf at
        ``self.template_cache_path``.

        Args:
            template_link: self-link (or name) of the instance template.
            project: project owning the template; defaults to ``self.project``.
        """
        project = project or self.project
        name = trim_self_link(template_link)

        # Reads and writes are done in separate openings so that the write
        # lock is held as briefly as possible. (Maybe slower? TODO measure.)
        if self.template_cache_path.exists():
            with self.template_cache() as cache:
                if name in cache:
                    return NSDict(cache[name])

        request = self.compute.instanceTemplates().get(
            project=project, instanceTemplate=name
        )
        template = NSDict(ensure_execute(request).get("properties"))
        # The name and link are not part of "properties"; attach them here.
        template.name = name
        template.link = template_link
        # TODO delete metadata to reduce memory footprint?
        # del template.metadata

        # Normalize GPU info: accelerators bundled with the machine type take
        # precedence, then accelerators attached by the template, else none.
        machine_info = self.machine_type(template.machineType, project=project)
        gpu_type, gpu_count = None, 0
        if machine_info.accelerators:
            bundled = machine_info.accelerators[0]
            gpu_type = bundled.guestAcceleratorType
            gpu_count = bundled.guestAcceleratorCount
        elif template.guestAccelerators:
            attached = template.guestAccelerators[0]
            gpu_type = attached.acceleratorType
            gpu_count = attached.acceleratorCount
        template.gpu_type = gpu_type
        template.gpu_count = gpu_count

        # Hold write access only long enough to store the new entry.
        with self.template_cache(writeback=True) as cache:
            cache[name] = template.to_dict()
        # The cache file should be owned by slurm.
        chown_slurm(self.template_cache_path)

        return template

    def nodeset_map(self, hostnames: list):
        """Group hostnames by the nodeset each one belongs to.

        Returns a ``collections.defaultdict(list)`` mapping nodeset name to
        the hostnames in that nodeset, preserving input order.
        """
        groups = collections.defaultdict(list)
        pairs = ((self.node_nodeset_name(host), host) for host in hostnames)
        for nodeset_name, host in pairs:
            groups[nodeset_name].append(host)
        return groups


# Define late globals
# NOTE(review): a default Lookup is bound to ``lkp`` before the config is
# loaded, presumably because the config loaders below (or helpers they call)
# read the module-level ``lkp`` — confirm before reordering these lines.
lkp = Lookup()
# Prefer the local config file. If it yields nothing, fall back to fetching
# the config from the bucket and, if that succeeds, save a local copy.
cfg = load_config_file(CONFIG_FILE)
if not cfg:
    try:
        cfg = fetch_config_yaml()
    except Exception as e:
        # Best effort: on failure keep whatever load_config_file returned.
        log.warning(f"config not found in bucket: {e}")
    if cfg:
        save_config(cfg, CONFIG_FILE)

# Rebind ``lkp`` to a Lookup built from the loaded config (which may still be
# empty if neither source provided one).
lkp = Lookup(cfg)

# Needs to be run after the lookup is complete to get endpoint versions
compute = compute_service()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--partitions",
        "-p",
        help="The partition(s) to retrieve the TPU vmcount value for.",
    )
    args = parser.parse_args()
    if args.partitions:
        # Status values printed as "VMCOUNT:<status>" in place of a vmcount.
        # (They are printed, not used as the process exit code.)
        # A partition is not in config.yaml, and thus does not exist in Slurm.
        PART_INVALID = -1
        # A single partition contains nodesets with different vmcounts.
        DIFF_VMCOUNTS_SAME_PART = -2
        # At least two of the listed partitions have different vmcounts.
        DIFF_PART_DIFFERENT_VMCOUNTS = -3

        status = 0  # 0 means every check has passed so far
        counts = []
        for part_name in args.partitions.split(","):
            if part_name not in lkp.cfg.partitions:
                status = PART_INVALID
                break
            if not part_is_tpu(part_name):
                # Non-TPU partitions count as vmcount 0.
                counts.append(0)
                continue
            count = get_vmcount_of_tpu_part(part_name)
            if count == -1:
                status = DIFF_VMCOUNTS_SAME_PART
                break
            counts.append(count)

        # All requested partitions must agree on one vmcount.
        if status == 0 and len(set(counts)) != 1:
            status = DIFF_PART_DIFFERENT_VMCOUNTS

        reported = counts[0] if status == 0 else status
        print(f"VMCOUNT:{reported}")
