in distributed_training/src_dir/dis_util.py [0:0]
def dist_setting(args):
    """Resolve world_size / rank / local_rank for the selected distribution
    mode (SageMaker data parallel, SageMaker model parallel, or native
    PyTorch DDP) and scale the learning rate and batch size accordingly."""
    # args.data_parallel = False
    print(f"args.data_parallel : {args.data_parallel}, "
          f"args.model_parallel : {args.model_parallel}, args.apex : {args.apex}")

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)
    if args.data_parallel:
        # SageMaker distributed data parallel (smdistributed.dataparallel)
        sdp, DDP = _sdp_import(args)
        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()              # global rank across all hosts
        args.local_rank = sdp.get_local_rank()  # rank within the current host
    elif args.model_parallel:
        # SageMaker distributed model parallel (smdistributed.modelparallel)
        args.world_size = smp.size()
        args.world_size = args.num_gpus * len(args.hosts)  # overrides smp.size()
        args.local_rank = smp.local_rank()  # rank within the current host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
    else:
        # Native PyTorch DDP: derive the global rank and initialize the
        # default process group explicitly.
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            args.rank = args.num_gpus * args.host_num + \
                args.local_rank  # global rank across all hosts
        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))
    # if not args.model_parallel:
    # Linear LR scaling with the total number of workers; the per-process
    # batch size is the configured batch size divided by the number of hosts.
    args.lr = args.lr * float(args.world_size)
    args.batch_size //= args.world_size // args.num_gpus
    args.batch_size = max(args.batch_size, 1)

    return args
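
For context, the sketch below shows how a training entry point might populate args before calling dist_setting. It is a minimal, hypothetical example: the parse_args helper, its flags, and its defaults are assumptions, while the attribute names it sets are exactly those dist_setting reads, and SM_HOSTS, SM_CURRENT_HOST, and SM_NUM_GPUS are the standard SageMaker environment variables. dist_setting itself additionally relies on module-level names in dis_util.py (dist, smp, logger, _sdp_import) that are outside this excerpt.

# Hypothetical entry-point sketch (not part of dis_util.py).
import argparse
import json
import os


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", type=str, default="nccl")
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--data_parallel", action="store_true")
    parser.add_argument("--model_parallel", action="store_true")
    parser.add_argument("--apex", action="store_true")
    parser.add_argument("--local_rank", type=int,
                        default=int(os.environ.get("LOCAL_RANK", 0)))
    args = parser.parse_args()

    # SageMaker publishes the cluster layout through these environment variables.
    args.hosts = json.loads(os.environ.get("SM_HOSTS", '["algo-1"]'))
    args.current_host = os.environ.get("SM_CURRENT_HOST", "algo-1")
    args.num_gpus = int(os.environ.get("SM_NUM_GPUS", "1"))
    return args


if __name__ == "__main__":
    args = parse_args()
    args = dist_setting(args)
    print(f"rank={args.rank}, world_size={args.world_size}, "
          f"lr={args.lr}, batch_size={args.batch_size}")

Note that the native DDP branch calls dist.init_process_group with the default env:// init method, so MASTER_ADDR and MASTER_PORT must also be present in the environment (torch.distributed.launch or torchrun sets them).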