scripts/suspend.py (134 lines of code) (raw):

#!/usr/bin/env python3

# Copyright (C) SchedMD LLC.
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Slurm SuspendProgram for slurm-gcp.

Takes a node list, ignores nodes that are not power managed, and removes the
compute instances backing the rest. TPU nodes are handled separately: single-VM
TPU nodes in a nodeset with ``preserve_tpu`` set are stopped instead of
deleted.
"""

from typing import List
import argparse
import logging
import sys
from pathlib import Path

import util
from util import (
    groupby_unsorted,
    log_api_request,
    batch_execute,
    to_hostlist_fast,
    wait_for_operations,
    separate,
    execute_with_futures,
)
from util import lkp, cfg, compute, TPU

import slurm_gcp_plugins

filename = Path(__file__).name
# Fall back to the current directory when no config is loaded.
LOGFILE = (Path(cfg.slurm_log_dir if cfg else ".") / filename).with_suffix(".log")

log = logging.getLogger(filename)

TOT_REQ_CNT = 1000


def truncate_iter(iterable, max_count):
    """Yield items from ``iterable``, cutting it short with a "..." marker.

    Items are yielded one by one until the ``max_count``-th item is reached;
    that item is replaced by the string "..." and iteration stops. Shorter
    iterables are yielded unchanged.
    """
    end = "..."
    _iter = iter(iterable)
    for i, el in enumerate(_iter, start=1):
        if i >= max_count:
            yield end
            break
        yield el


def delete_instance_request(instance, project=None, zone=None):
    """Build (and log, but do not execute) a compute API delete request.

    Args:
        instance: name of the instance to delete.
        project: GCP project; defaults to ``lkp.project``.
        zone: instance zone; defaults to the zone reported by
            ``lkp.instance(instance)``.

    Returns:
        The unexecuted ``compute.instances().delete(...)`` request.
    """
    project = project or lkp.project
    request = compute.instances().delete(
        project=project,
        zone=(zone or lkp.instance(instance).zone),
        instance=instance,
    )
    log_api_request(request)
    return request


def stop_tpu(data):
    """Stop or delete a single TPU node.

    Args:
        data: dict with keys "tpu" (the TPU helper for the node's nodeset),
            "node" (node name) and "nodeset" (the node's TPU nodeset).

    A node is stopped when its nodeset has ``preserve_tpu`` set and the TPU
    has exactly one VM. Otherwise it is deleted; if a stop attempt reports
    failure, the node is deleted as a fallback. A falsy result from
    ``stop_node``/``delete_node`` is logged rather than raised.
    """
    tpu_nodeset = data["nodeset"]
    node = data["node"]
    tpu = data["tpu"]
    if tpu_nodeset.preserve_tpu and tpu.vmcount == 1:
        log.info(f"stopping node {node}")
        if tpu.stop_node(node):
            return
        log.error(f"Error stopping node {node} will delete instead")
    log.info(f"deleting node {node}")
    if not tpu.delete_node(node):
        log.error(f"Error deleting node {node}")


def delete_tpu_instances(instances):
    """Stop or delete the given TPU nodes via ``stop_tpu``.

    Nodes are grouped by ``lkp.node_prefix`` and one ``TPU`` helper is built
    per group from the first node's nodeset. This assumes all nodes that share
    a prefix belong to the same nodeset. The per-node work is handed to
    ``execute_with_futures``.
    """
    stop_data = []
    for prefix, nodes in groupby_unsorted(instances, lkp.node_prefix):
        log.info(f"Deleting TPU nodes from prefix {prefix}")
        lnodes = list(nodes)
        tpu_nodeset = lkp.node_nodeset(lnodes[0])
        tpu = TPU(tpu_nodeset)
        stop_data.extend(
            {"tpu": tpu, "node": node, "nodeset": tpu_nodeset} for node in lnodes
        )
    execute_with_futures(stop_tpu, stop_data)


def delete_instances(instances):
    """Delete compute instances individually, submitted as one batch.

    Names for which ``lkp.instance`` finds nothing are skipped with a debug
    message. Failed requests are logged, grouped by error. The function then
    waits for the operations of the successful requests before logging the
    result.
    """
    invalid, valid = separate(lambda inst: bool(lkp.instance(inst)), instances)
    if invalid:
        log.debug("instances do not exist: {}".format(",".join(invalid)))
    if not valid:
        log.debug("No instances to delete")
        return

    requests = {inst: delete_instance_request(inst) for inst in valid}

    log.info(f"delete {len(valid)} instances ({to_hostlist_fast(valid)})")
    done, failed = batch_execute(requests)
    if failed:
        # groupby_unsorted takes (sequence, key), matching its use in
        # delete_tpu_instances. The previous call had the arguments reversed.
        # failed[name][1] is presumably the error from batch_execute.
        for err, nodes in groupby_unsorted(list(failed), lambda n: failed[n][1]):
            log.error(
                f"instances failed to delete: {err} ({to_hostlist_fast(list(nodes))})"
            )
    wait_for_operations(done.values())
    # TODO do we need to check each operation for success? That is a lot more API calls
    log.info(f"deleted {len(done)} instances {to_hostlist_fast(done.keys())}")


def suspend_nodes(nodes: List[str]) -> None:
    """Split ``nodes`` into TPU and non-TPU nodes and suspend each group.

    Non-TPU nodes go to ``delete_instances``; TPU nodes go to
    ``delete_tpu_instances``.
    """
    tpu_nodes, other_nodes = [], []
    for node in nodes:
        if lkp.node_is_tpu(node):
            tpu_nodes.append(node)
        else:
            other_nodes.append(node)

    delete_instances(other_nodes)
    delete_tpu_instances(tpu_nodes)


def main(nodelist):
    """Entry point when run as a script.

    Args:
        nodelist: node list expression, expanded with ``util.to_hostnames``.

    Nodes that are not power managed are ignored. If any nodes remain, the
    optional plugin hook runs (when ``enable_slurm_gcp_plugins`` is set) and
    the remaining nodes are suspended.
    """
    log.debug(f"SuspendProgram {nodelist}")

    # Filter out nodes not in config.yaml
    other_nodes, pm_nodes = separate(
        lkp.is_power_managed_node, util.to_hostnames(nodelist)
    )
    if other_nodes:
        log.debug(
            f"Ignoring non-power-managed nodes '{to_hostlist_fast(other_nodes)}' from '{nodelist}'"
        )

    if pm_nodes:
        log.debug(f"Suspending nodes '{to_hostlist_fast(pm_nodes)}' from '{nodelist}'")
    else:
        log.debug("No cloud nodes to suspend")
        return

    log.info(f"suspend {nodelist}")
    if lkp.cfg.enable_slurm_gcp_plugins:
        slurm_gcp_plugins.pre_main_suspend_nodes(lkp=lkp, nodelist=nodelist)
    suspend_nodes(pm_nodes)


parser = argparse.ArgumentParser(
    description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("nodelist", help="list of nodes to suspend")
parser.add_argument(
    "--debug",
    "-d",
    dest="loglevel",
    action="store_const",
    const=logging.DEBUG,
    default=logging.INFO,
    help="Enable debugging output",
)
parser.add_argument(
    "--trace-api",
    "-t",
    action="store_true",
    help="Enable detailed api request output",
)


if __name__ == "__main__":
    args = parser.parse_args()

    if cfg.enable_debug_logging:
        args.loglevel = logging.DEBUG
    if args.trace_api:
        # Copy before appending so the original flags container is not mutated.
        cfg.extra_logging_flags = list(cfg.extra_logging_flags)
        cfg.extra_logging_flags.append("trace_api")
    util.chown_slurm(LOGFILE, mode=0o600)
    util.config_root_logger(filename, level=args.loglevel, logfile=LOGFILE)
    log = logging.getLogger(Path(__file__).name)
    sys.excepthook = util.handle_exception

    main(args.nodelist)