in main.py [0:0]
def main():
    """Set up the distributed-training configuration and spawn the workers.

    Parses the command line with the module-level ``parser`` and then fills
    in the following attributes on the resulting namespace:

    * ``args.ngpus_per_node`` -- number of CUDA devices visible to this
      process.
    * ``args.rank`` -- ``SLURM_NODEID * ngpus_per_node`` under SLURM,
      ``0`` otherwise.
    * ``args.world_size`` -- ``SLURM_NNODES * ngpus_per_node`` under SLURM,
      ``ngpus_per_node`` otherwise.
    * ``args.dist_url`` -- ``tcp://<host>:58472`` rendezvous address.

    Finally ``main_worker`` is launched through ``torch.multiprocessing.spawn``
    with ``args.ngpus_per_node`` processes and ``(args,)`` as extra arguments.

    Raises:
        RuntimeError: if no CUDA device is visible, since zero worker
            processes would be requested.
        KeyError: if ``SLURM_JOB_ID`` is set but ``SLURM_JOB_NODELIST``,
            ``SLURM_NODEID`` or ``SLURM_NNODES`` is missing.
        subprocess.CalledProcessError: if ``scontrol`` exits with a
            non-zero status.
    """
    args = parser.parse_args()
    args.ngpus_per_node = torch.cuda.device_count()
    if args.ngpus_per_node < 1:
        # With zero devices the world size would be 0 and spawn would be
        # asked for zero processes; fail loudly instead of doing nothing.
        raise RuntimeError(
            'no CUDA devices are visible; at least one GPU is required '
            'to launch distributed training'
        )

    # Single fixed rendezvous port shared by the SLURM and local code paths.
    port = 58472

    if 'SLURM_JOB_ID' in os.environ:
        # single-node and multi-node distributed training on SLURM cluster
        # requeue job on SLURM preemption
        # NOTE(review): the handlers are defined elsewhere in this file;
        # the requeue behaviour is taken from the original comment.
        signal.signal(signal.SIGUSR1, handle_sigusr1)
        signal.signal(signal.SIGTERM, handle_sigterm)

        # find a common host name on all nodes
        # assume scontrol returns hosts in the same order on all nodes
        # os.environ[...] is used so that a missing variable is reported
        # by name rather than as a TypeError from str + None / int(None).
        node_list = os.environ['SLURM_JOB_NODELIST']
        stdout = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', node_list]
        )
        host_name = stdout.decode().splitlines()[0]

        # presumably main_worker offsets this base rank by the local process
        # index it receives from spawn -- TODO confirm against main_worker
        args.rank = int(os.environ['SLURM_NODEID']) * args.ngpus_per_node
        args.world_size = int(os.environ['SLURM_NNODES']) * args.ngpus_per_node
        args.dist_url = f'tcp://{host_name}:{port}'
    else:
        # single-node distributed training
        args.rank = 0
        args.dist_url = f'tcp://localhost:{port}'
        args.world_size = args.ngpus_per_node

    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)