# submit_slurm_jobs.py
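"""Submit benchmark jobs to a Slurm cluster.

Walks an input directory of per-job folders (each holding a config.json and a
status.txt), renders a Slurm batch script for every job from
template/base_job.slurm, and submits it with sbatch, optionally chaining the
jobs with afterany dependencies. Job state is tracked in status.txt via the
Status enum below.
"""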
import json
import os
import subprocess
from enum import Enum
from typing import List, Optional

from jinja2 import Template


class Status(Enum):
    # INIT -> PENDING -> [RUNNING | FAIL | TIMEOUT | OOM] -> COMPLETED
    INIT = "init"            # Job is created
    PENDING = "pending"      # Job is waiting for resources
    RUNNING = "running"      # Job is running
    FAIL = "fail"            # Job failed
    OOM = "oom"              # Job failed due to out of memory (expected behavior)
    TIMEOUT = "timeout"      # Job failed due to timeout
    COMPLETED = "completed"  # Job is completed


class Job:
    def __init__(self, root_path: str, qos: str) -> None:
        self.root_path = root_path
        self.name = os.path.basename(root_path)
        self.config = os.path.join(root_path, "config.json")
        self.qos = qos

        # Create status.txt with INIT status if it does not exist yet
        status_file_path = os.path.join(self.root_path, "status.txt")
        if not os.path.exists(status_file_path):
            with open(status_file_path, 'w') as f:
                f.write(Status.INIT.value)

        self.status = self.get_status()

    def get_status(self) -> Status:
        """
        Read the status of the job from `status.txt` and return it
        """
        status_file_path = os.path.join(self.root_path, "status.txt")
        with open(status_file_path, 'r') as f:
            status = f.read().strip()
        try:
            return Status(status)
        except ValueError:
            raise ValueError(f"Invalid status {status!r} in {status_file_path}") from None

    def set_status(self, status: Status) -> Status:
        """
        Update the status of the job in `status.txt` and return the new status
        """
        status_file_path = os.path.join(self.root_path, "status.txt")
        with open(status_file_path, 'w') as f:
            f.write(status.value)
        return status


class Scheduler:
    def __init__(self, inp_dir: str, qos: str) -> None:
        # Each leaf directory under inp_dir is one job; profiler subdirectories
        # are mapped back to their parent job directory.
        jobs_directory_paths = [os.path.abspath(root) for root, dirs, _ in os.walk(inp_dir) if not dirs]
        jobs_directory_paths = [job_path.replace("/profiler", "") for job_path in jobs_directory_paths]
        self.job_lists = [Job(job_path, qos) for job_path in jobs_directory_paths]

    def keep_only_jobs(self, status: Status) -> List[Job]:
        return [job for job in self.job_lists if job.status == status]

    def filter_out_jobs(self, status: Status) -> List[Job]:
        return [job for job in self.job_lists if job.status != status]

    def create_slurm_script(self, job: Job):
        # Render the job's Slurm script from the Jinja template
        with open(job.config, 'r') as file:
            config = json.load(file)

        max_gpu_per_node = 8
        # Pick the right number of nodes and n_proc_per_node
        world_size = config["distributed"]["tp_size"] * config["distributed"]["cp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"]
        assert world_size <= max_gpu_per_node or world_size % max_gpu_per_node == 0
        nodes = max(1, world_size // max_gpu_per_node)
        n_proc_per_node = min(max_gpu_per_node, world_size // nodes)
        assert nodes * n_proc_per_node == world_size
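        # e.g. world_size=4  -> nodes=1, n_proc_per_node=4
        #      world_size=16 -> nodes=2, n_proc_per_node=8
        # World sizes above 8 must be a multiple of 8, otherwise the assert above fires.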

        context_bench = {
            "nodes": nodes,
            "n_proc_per_node": n_proc_per_node,
            "root_path": job.root_path,
            "config": job.config,
            "qos": job.qos,
        }

        base_path = os.path.join(os.getcwd(), "template/base_job.slurm")
        with open(base_path, 'r') as file:
            base_job_file = file.read()
        base_job_template = Template(base_job_file)

        # Write the rendered script to a new file located at the job root_path
        output_file_path = os.path.join(job.root_path, "job.slurm")
        with open(output_file_path, 'w') as file:
            file.write(base_job_template.render(context_bench))

        print(f"Slurm script created at {output_file_path}")

    def launch_dependency(self, job_array: List[Job], env_vars):
        prev_job_id = None
        for job in job_array:
            cmd = ["sbatch", "--parsable", os.path.join(job.root_path, "job.slurm")]
            if prev_job_id is not None:
                # Chain onto the previous job so the jobs in the array run one after another
                cmd.insert(2, f"--dependency=afterany:{prev_job_id}")
            result = subprocess.run(cmd, env=env_vars, capture_output=True, text=True)
            job.set_status(Status.PENDING)
            prev_job_id = result.stdout.strip()
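
    # The loop above effectively runs, in order:
    #   sbatch --parsable <job1>/job.slurm                              -> prints <jobid1>
    #   sbatch --parsable --dependency=afterany:<jobid1> <job2>/job.slurm
    #   sbatch --parsable --dependency=afterany:<jobid2> <job3>/job.slurm
    #   ...
    # so each job only starts once the previous one has terminated (in any state).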

    def check_status(self):
        # Read every job's status.txt and print a summary table
        status_files = [os.path.join(job.root_path, "status.txt") for job in self.job_lists]
        status_counts = {status.value: 0 for status in Status}

        for status_file in status_files:
            with open(status_file, 'r') as f:
                status = f.read().strip()
            if status in status_counts:
                status_counts[status] += 1
            else:
                raise ValueError(f"Invalid status: {status}")

        total = sum(status_counts.values())

        # Print the status counts in a formatted table
        print(f"{'Status':<10} | {'Count':<6}")
        print(f"{'-'*10}-|-{'-'*6}")
        for status, count in status_counts.items():
            print(f"{status.capitalize():<10} | {count:<6}")
        print(f"{'-'*10}-|-{'-'*6}")
        print(f"{'Total':<10} | {total:<6}")


def submit_jobs(inp_dir, qos, hf_token, nb_slurm_array, only: Optional[str] = None):
    scheduler = Scheduler(inp_dir, qos)
    # TODO: batch into job arrays
    env_vars = os.environ.copy()
    env_vars["HUGGINGFACE_TOKEN"] = hf_token
    total_jobs = len(scheduler.job_lists)

    # Optionally resubmit only the jobs in a given state
    if only == "fail":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.FAIL)
    elif only == "pending":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.PENDING)
    elif only == "timeout":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.TIMEOUT)
    elif only == "running":
        scheduler.job_lists = scheduler.keep_only_jobs(Status.RUNNING)

    if only is not None:
        filtered_jobs = len(scheduler.job_lists)
        if filtered_jobs == 0:
            print(f"No '{only}' jobs to resubmit")
            return
        print(f"Only {filtered_jobs}/{total_jobs} jobs with status '{only}' will be resubmitted")

    # Never resubmit jobs that already completed
    scheduler.job_lists = scheduler.filter_out_jobs(Status.COMPLETED)
    if nb_slurm_array > 0:
        # Use job dependencies: distribute the jobs across nb_slurm_array chains
        base_jobs_per_array = len(scheduler.job_lists) // nb_slurm_array
        extra_jobs = len(scheduler.job_lists) % nb_slurm_array
        distribution = [base_jobs_per_array] * nb_slurm_array
        for i in range(extra_jobs):
            distribution[i] += 1
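        # e.g. 10 jobs with nb_slurm_array=3 -> distribution = [4, 3, 3]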

        start = 0
        for i, nb_jobs in enumerate(distribution):
            end = start + nb_jobs
            job_array = scheduler.job_lists[start:end]
            print(f"Launching job dependency array {i+1} with {nb_jobs} jobs")
            for job in job_array:
                scheduler.create_slurm_script(job)
            scheduler.launch_dependency(job_array, env_vars)
            start = end
    else:
        # Don't use job dependencies: submit every job independently
        for job in scheduler.job_lists:
            scheduler.create_slurm_script(job)
            print(os.path.join(job.root_path, "job.slurm"))
            subprocess.run(["sbatch", os.path.join(job.root_path, "job.slurm")], env=env_vars)
            job.set_status(Status.PENDING)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Submit jobs to the cluster')
    parser.add_argument('--inp_dir', type=str, help='Input directory containing the jobs')
    parser.add_argument('--qos', type=str, help='QOS of the jobs')
    parser.add_argument('--nb_slurm_array', type=int, default=0, help='Number of slurm job-dependency arrays (0 = submit jobs independently)')
    parser.add_argument('--only', type=str, default=None, choices=['fail', 'pending', 'timeout', 'running'], help='Only resubmit jobs with this status')
    parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
    args = parser.parse_args()

    # TODO: add more commands, e.g. "python slurm.py submit_jobs --...", "python slurm.py update_jobs --....",
    #       "python slurm.py cancel_jobs --..." or "python slurm.py check_status --..."
    submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
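
    # Example invocation (paths and values are illustrative):
    #   python submit_slurm_jobs.py --inp_dir results/benchmark --qos normal \
    #       --hf_token $HF_TOKEN --nb_slurm_array 4 --only fail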