slurm-to-batch/convert_to_batch_job.py

import os
import json
import yaml
import sys

from slurm_script_parser import SlurmJobConfig, SlurmScriptParser


def generate_gres_conf_script(slurm_config: SlurmJobConfig) -> str:
    """Generates a shell snippet that writes Slurm's gres.conf for the node's GPUs."""
    gres_conf_script = f"""#!/bin/bash
# Script to configure Slurm's GPU resources in gres.conf
cat <<EOF > /usr/local/etc/slurm/gres.conf
# Define GPU resources
AutoDetect=nvml
"""
    if slurm_config.gpu_per_node > 0 and slurm_config.gpu_type not in ["None", None, ""]:
        for i in range(slurm_config.gpu_per_node):
            gres_conf_script += f"Name=gpu Type={slurm_config.gpu_type} File=/dev/nvidia{i}\n"
    else:
        for i in range(slurm_config.gpu_per_node):
            gres_conf_script += f"Name=gpu File=/dev/nvidia{i}\n"
    gres_conf_script += "EOF\n"
    return gres_conf_script


def generate_slurm_conf_script(slurm_config: SlurmJobConfig) -> str:
    """Generates a shell snippet that writes slurm.conf for a single-partition cluster."""
    node_count = slurm_config.node_count
    slurm_conf_script_fixed = """
cat <<EOF > /usr/local/etc/slurm/slurm.conf
ClusterName=${BATCH_JOB_ID}
SlurmctldHost=$(head -1 ${BATCH_HOSTS_FILE})
AuthType=auth/munge
ProctrackType=proctrack/pgid
ReturnToService=2
# For GPU resource
GresTypes=gpu
SlurmctldPidFile=/var/run/slurm/slurmctld.pid
SlurmdPidFile=/var/run/slurm/slurmd.pid
# slurm logs
SlurmdLogFile=/var/log/slurm/slurmd.log
SlurmctldLogFile=/var/log/slurm/slurmctld.log
SlurmdSpoolDir=/var/spool/slurmd
SlurmUser=root
StateSaveLocation=/var/spool/slurmctld
TaskPlugin=task/none
SchedulerType=sched/backfill
SelectTypeParameters=CR_Core
# Turn off both types of accounting
JobAcctGatherFrequency=0
JobAcctGatherType=jobacct_gather/none
AccountingStorageType=accounting_storage/none
SlurmctldDebug=3
SlurmdDebug=3
SelectType=select/cons_tres
"""
    slurm_conf_script_not_fixed = f"MaxNodeCount={node_count}\nPartitionName=all Nodes=ALL Default=yes\nEOF"
    return slurm_conf_script_fixed + slurm_conf_script_not_fixed


def start_slurmctld() -> str:
    """Generates a shell snippet that starts slurmctld on node 0 and waits for it to come up."""
    return f"""
mkdir -p /var/spool/slurm
chmod 755 /var/spool/slurm/
touch /var/log/slurmctld.log
mkdir -p /var/log/slurm
touch /var/log/slurm/slurmd.log /var/log/slurm/slurmctld.log
touch /var/log/slurm_jobacct.log /var/log/slurm_jobcomp.log
rm -rf /var/spool/slurmctld/*
if [[ "$BATCH_NODE_INDEX" == "0" ]]; then
  systemctl start slurmctld
  MAX_RETRIES=5
  RETRY_INTERVAL=5
  for (( i=1; i<=MAX_RETRIES; i++ )); do
    if systemctl is-active --quiet slurmctld; then
      echo "slurmctld is running."
      break
    fi
    echo "Services not running. Retrying in $RETRY_INTERVAL seconds..."
    sleep $RETRY_INTERVAL
  done
fi
"""


def start_slurmd(slurm_config: SlurmJobConfig) -> str:
    """Generates a shell snippet that starts slurmd with the node's GPU Gres and verifies it."""
    gpu_per_node = slurm_config.gpu_per_node
    return f"""#!/bin/bash
/usr/local/sbin/slurmd -Z --conf "Gres=gpu:{gpu_per_node}"
RETRIES=5
WAIT_TIME=1
for (( i=1; i<=$RETRIES; i++ )); do
  if ps -ef | grep -v grep | grep slurmd > /dev/null; then
    echo "slurmd is running!"
    exit 0
  else
    echo "slurmd not found, retrying in $WAIT_TIME seconds..."
    sleep $WAIT_TIME
  fi
done
echo "slurmd did not start after $RETRIES attempts."
exit 1
"""


def work_load() -> str:
    """Generates the runnable that executes the user's srun command on node 0."""
    return """#!/bin/bash
if [[ "$BATCH_NODE_INDEX" == "0" ]]; then
  <SRUN_COMMAND>
fi
"""


def createJobJSON(slurm_conf: SlurmJobConfig) -> dict:
    """Builds the Google Cloud Batch job definition that bootstraps Slurm and runs the workload."""
    slurm_setup = (
        generate_gres_conf_script(slurm_conf)
        + generate_slurm_conf_script(slurm_conf)
        + start_slurmctld()
    )
    job_definition = {
        "taskGroups": [
            {
                "task_spec": {
                    "runnables": [
                        {
                            "script": {
                                "text": slurm_setup,
                            },
                        },
                        {"barrier": {"name": "slurmctld-started"}},
                        {"script": {"text": start_slurmd(slurm_conf)}},
                        {"barrier": {"name": "slurmd-started-all-vms"}},
                        {"script": {"text": work_load()}},
                        {"barrier": {"name": "slurm-job-finished-all-vms"}},
                    ],
                },
                "task_count_per_node": 1,
                "task_count": slurm_conf.node_count,
                "require_hosts_file": True,
            }
        ],
        "allocation_policy": {
            "location": {"allowed_locations": ["zones/<SELECTED_ZONE>"]},
            "instances": {
                "policy": {
                    "accelerators": {
                        "type": "<CUSTOM_GPU_TYPE>",
                        "count": slurm_conf.gpu_per_node,
                    },
                    "boot_disk": {
                        "image": "<CUSTOM_BOOT_IMAGE>",
                        "size_gb": "<CUSTOM_BOOT_DISK_SIZE>",
                    },
                },
                "install_gpu_drivers": True,
            },
        },
        "labels": {"goog-batch-dynamic-workload-scheduler": "true"},
        "logs_policy": {"destination": "CLOUD_LOGGING"},
    }
    return job_definition


def str_presenter(dumper, data):
    """Configures yaml for dumping multiline strings.

    Ref: https://stackoverflow.com/questions/8640959/how-can-i-control-what-scalar-form-pyyaml-uses-for-my-data
    """
    if data.count('\n') > 0:  # check for multiline string
        return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
    return dumper.represent_scalar('tag:yaml.org,2002:str', data)


def main():
    if len(sys.argv) != 2 and len(sys.argv) != 3:
        print(
            'Usage: python3 convert_to_batch_job.py <slurm_script_path> <batch_template_folder>(optional)'
        )
        sys.exit(1)
    slurm_script_path = sys.argv[1]
    if len(sys.argv) == 2:
        # Use the slurm_script_path's folder.
        output_dir = os.path.dirname(slurm_script_path)
    else:
        output_dir = sys.argv[2]
        os.makedirs(output_dir, exist_ok=True)

    config = SlurmScriptParser.parse_slurm_script(slurm_script_path)

    # Dump multiline script strings as YAML block scalars for readability.
    yaml.add_representer(str, str_presenter)
    yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
    yaml_data = yaml.dump(createJobJSON(config), allow_unicode=True)

    yaml_file_path = os.path.join(output_dir, f"{config.job_name}_template.yaml")
    with open(yaml_file_path, 'w', encoding='utf-8') as file:
        file.write(yaml_data)


if __name__ == '__main__':
    main()
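
# Example invocation (a sketch; the Slurm script name "train_job.sh" and the output
# folder "batch_templates" are hypothetical):
#
#   python3 convert_to_batch_job.py train_job.sh batch_templates
#
# This writes batch_templates/<job_name>_template.yaml. Before submitting, replace the
# <SELECTED_ZONE>, <CUSTOM_GPU_TYPE>, <CUSTOM_BOOT_IMAGE>, <CUSTOM_BOOT_DISK_SIZE> and
# <SRUN_COMMAND> placeholders emitted by createJobJSON; the resulting template can then
# be submitted to Google Cloud Batch, for example with `gcloud batch jobs submit` and
# its --config flag.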