in torchbenchmark/models/demucs/run_slurm.py [0:0]
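# NOTE: the imports and the free_port() helper are not part of this snippet;
# they are reconstructed here so the excerpt reads standalone. The aliases
# follow the usage in the body (sp = subprocess, th = torch). free_port() is
# a hypothetical minimal sketch, assuming the real helper simply asks the OS
# for an available ephemeral port.
import os
import socket
import subprocess as sp
import sys
import time

import torch as th


def free_port(host: str = '') -> int:
    # Binding to port 0 lets the kernel pick a free port; report its number.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]
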
def main():
    args = sys.argv[1:]
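    # Cluster topology: one worker process per visible GPU on every node of
    # the SLURM job, read from the environment variables SLURM exports.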
    gpus = th.cuda.device_count()
    n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
    node_id = int(os.environ['SLURM_NODEID'])
    job_id = int(os.environ['SLURM_JOBID'])
    # Global ranks for this node start right after the previous nodes' GPUs.
    rank_offset = gpus * node_id
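    # Resolve the job's node list into hostnames; the first host acts as the
    # rendezvous master for distributed training.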
    hostnames = sp.run(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']],
                       capture_output=True,
                       check=True).stdout
    master_addr = hostnames.split()[0].decode('utf-8')
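    # Choose the rendezvous port: on a single node any locally free port
    # works, but across nodes the port must be agreed on without
    # communication, so it is derived deterministically from the job id.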
    if n_nodes == 1:
        port = free_port()
    else:
        port = 20_000 + (job_id % 40_000)
    args += ["--world_size", str(n_nodes * gpus), "--master", f"{master_addr}:{port}"]
    tasks = []
    print("About to go live", master_addr, node_id, n_nodes, file=sys.stderr)
    sys.stderr.flush()
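    # Spawn one demucs worker per GPU. All children except the first have
    # stdin/stdout silenced so only one process writes to the terminal.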
    for gpu in range(gpus):
        kwargs = {}
        if gpu > 0:
            kwargs['stdin'] = sp.DEVNULL
            kwargs['stdout'] = sp.DEVNULL
            # We keep stderr to see tracebacks from children.
        tasks.append(
            sp.Popen(["python3", "-m", "demucs"] + args +
                     ["--rank", str(rank_offset + gpu)], **kwargs))
        # Stash each child's global rank on the Popen object for reporting.
        tasks[-1].rank = rank_offset + gpu
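    # Poll the children until they have all exited or one of them fails;
    # wait(0.1) raises TimeoutExpired while a child is still running.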
    failed = False
    try:
        while tasks:
            # Iterate over a copy: removing from the list while iterating
            # over it would silently skip the task after the removed one.
            for task in list(tasks):
                try:
                    exitcode = task.wait(0.1)
                except sp.TimeoutExpired:
                    continue
                else:
                    tasks.remove(task)
                    if exitcode:
                        print(f"Task {task.rank} died with exit code "
                              f"{exitcode}",
                              file=sys.stderr)
                        failed = True
                    else:
                        print(f"Task {task.rank} exited successfully")
            if failed:
                break
            time.sleep(1)
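    # On Ctrl-C, terminate any remaining children before re-raising.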
    except KeyboardInterrupt:
        for task in tasks:
            task.terminate()
        raise
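    # A child failed: kill the local survivors and cancel the whole SLURM job
    # so the workers on the other nodes exit as well, then report failure.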
    if failed:
        for task in tasks:
            task.terminate()
        sp.run(["scancel", str(job_id)], check=True)
        sys.exit(1)
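

if __name__ == "__main__":
    # Entry-point guard (assumed from context, not shown in the excerpt) so
    # the launcher runs when the file is executed directly.
    main()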