integration/preemption.py (78 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#

"""Preemption tests, need to be run on an actual cluster"""
import logging
import shutil
import subprocess
import time
from datetime import datetime
from pathlib import Path

import submitit
from submitit import AutoExecutor, Job
from submitit.core import test_core

FILE = Path(__file__)
# All submitit job folders/logs for this script live under ./logs/<script-stem>_log
LOGS = FILE.parent / "logs" / f"{FILE.stem}_log"

# Logger used by the driver process (the one that schedules the jobs).
log = logging.getLogger("preemption_main")
formatter = logging.Formatter("%(name)s %(levelname)s (%(asctime)s) - %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
log.setLevel(logging.INFO)
log.addHandler(handler)


def clock(partition: str, duration: int) -> int:
    """Job payload: tick once per minute, and report any interruption.

    Sleeps for ``duration - 5`` minutes in one-minute steps (no sleep at all
    when ``duration <= 5``), logging "tick"/"tack" alternately.

    Args:
        partition: only used to name the logger (``preemption_<partition>``).
        duration: nominal job length in minutes.

    Returns:
        ``duration`` when the loop finishes without being interrupted.

    Raises:
        Whatever interrupted the loop; it is re-raised unchanged after a
        "!!! Interrupted on: <iso timestamp>" warning has been emitted through
        the root logger. ``preemption()`` searches the job's stderr for that
        marker.
    """
    log = logging.getLogger(f"preemption_{partition}")
    tick_tack = ["tick", "tack"]
    try:
        for minute in range(duration - 5):
            log.info(tick_tack[minute % 2])
            time.sleep(60)
        logging.warning("*** Exited peacefully ***")
        return duration
    except BaseException:
        # Deliberately as broad as a bare ``except:``: the interruption may not
        # be an ``Exception`` subclass (e.g. SystemExit / KeyboardInterrupt).
        # NOTE(review): presumably this is how the preemption signal surfaces
        # inside the job -- confirm against submitit's signal handling.
        logging.warning(f"!!! Interrupted on: {datetime.now().isoformat()}")
        raise


def pascal_job(partition: str, timeout_min: int, node: str = "") -> Job:
    """Submit a job with specific constraint that we can preempt deterministically.

    Args:
        partition: SLURM partition to submit to; also used in the job name.
        timeout_min: job timeout in minutes, also passed to ``clock`` as its
            duration.
        node: when non-empty, pin the job to this node through the SLURM
            ``nodelist`` parameter.

    Returns:
        The submitted job, which runs ``clock(partition, timeout_min)``.
    """
    ex = submitit.AutoExecutor(folder=LOGS, slurm_max_num_timeout=1)
    ex.update_parameters(
        name=f"submitit_preemption_{partition}",
        timeout_min=timeout_min,
        mem_gb=7,
        slurm_constraint="pascal",
        slurm_comment="submitit integration test",
        slurm_partition=partition,
        # pascal nodes have 80 cpus.
        # By requesting 50 we know that there can be only one such job with this property.
        cpus_per_task=50,
        slurm_additional_parameters={},
    )
    if node:
        ex.update_parameters(slurm_additional_parameters={"nodelist": node})

    return ex.submit(clock, partition, timeout_min)


def wait_job_is_running(job: Job) -> None:
    """Block until ``job`` leaves the "UNKNOWN"/"PENDING" states.

    Polls ``job.state`` once a minute. It returns as soon as the state is
    anything else, which is not necessarily "RUNNING", and there is no timeout:
    a job that never gets scheduled keeps this loop going indefinitely.
    """
    while job.state in ("UNKNOWN", "PENDING"):
        log.info(f"{job} is not RUNNING")
        time.sleep(60)


def preemption() -> None:
    """Check that a low-priority job gets preempted by a high-priority one.

    Steps:
      1. start a long job on the "learnfair" partition and wait for it to start;
      2. submit a second job on the "dev" partition pinned to the same node;
      3. once the second job has started, verify that the first job logged
         exactly one interruption and is back in the "PENDING" state;
      4. wait for the second job to finish.

    Raises:
        AssertionError: if any of the checks of step 3 fails.
        ValueError: if the interruption line does not end with a valid ISO
            timestamp.
    """
    job = pascal_job("learnfair", timeout_min=2 * 60)
    log.info(f"Scheduled {job}, {job.paths.stdout}")
    # log.info(job.paths.submission_file.read_text())

    wait_job_is_running(job)
    node = job.get_info()["NodeList"]
    log.info(f"{job} ({job.state}) is runnning on {node} !")

    # Schedule another pascal job on the same node, with high priority
    priority_job = pascal_job("dev", timeout_min=15, node=node)
    # Report the state of the job named in the message (was: job.state).
    log.info(f"Schedule {priority_job} ({priority_job.state}) on {node} with high priority.")
    wait_job_is_running(priority_job)

    # if priority_job is running, then job should have been preempted
    learnfair_stderr = job.stderr()
    assert learnfair_stderr is not None, job.paths.stderr
    log.info(
        f"Job {priority_job} ({priority_job.state}) started, "
        f"job {job} ({job.state}) should have been preempted: {learnfair_stderr}"
    )
    interruptions = [line for line in learnfair_stderr.splitlines() if "Interrupted" in line]
    assert len(interruptions) == 1, interruptions
    # Exact comparison: ``x in ("PENDING")`` is a substring test on a plain
    # string (no trailing comma => no tuple), so "" or "PEND" would also pass.
    assert job.state == "PENDING", job.state

    interrupted_ts = interruptions[0].split("!!! Interrupted on: ")[-1]
    # Parsing is only a sanity check that the marker carries a well-formed
    # ISO timestamp; the parsed value itself is not needed.
    datetime.fromisoformat(interrupted_ts)

    priority_job.result()
    print("Preemption test succeeded ✅")


def main() -> None:
    """Wipe the log folder left by a previous run, then run the preemption test."""
    log.info("Hello !")
    if LOGS.exists():
        log.info(f"Cleaning up log folder: {LOGS}")
        shutil.rmtree(str(LOGS))

    preemption()


if __name__ == "__main__":
    main()