def dist_setting()

in distributed_training/src_dir/dis_util.py [0:0]
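
The function relies on names defined at module level elsewhere in dis_util.py. A minimal sketch of what those imports likely look like, assuming the SageMaker distributed training libraries and the standard logging module (the actual module-level setup is not shown in this excerpt):

import logging

import torch.distributed as dist

# SageMaker model parallelism, used in the args.model_parallel branch.
import smdistributed.modelparallel.torch as smp

logger = logging.getLogger(__name__)

# _sdp_import(args) is a helper defined elsewhere in dis_util.py; it presumably wraps
# "import smdistributed.dataparallel.torch.distributed as sdp" and returns that module
# together with its DistributedDataParallel wrapper (DDP).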


def dist_setting(args):
    print(f"args.data_parallel : {args.data_parallel}, "
          f"args.model_parallel : {args.model_parallel}, args.apex : {args.apex}")

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)  # index of this host among all hosts

    if args.data_parallel:
        # SageMaker data parallelism: _sdp_import() returns the smdistributed.dataparallel
        # distributed module (sdp) and its DistributedDataParallel wrapper (DDP).
        sdp, DDP = _sdp_import(args)

        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # global rank across all hosts
        args.local_rank = sdp.get_local_rank()  # rank within this host
    elif args.model_parallel:
        # SageMaker model parallelism (smp): one process per GPU on every host.
        args.world_size = args.num_gpus * len(args.hosts)  # equivalently smp.size()
        args.local_rank = smp.local_rank()  # rank within this host
        args.rank = smp.rank()  # global rank across all hosts
        args.dp_size = smp.dp_size()  # size of the data-parallel group
        args.dp_rank = smp.dp_rank()  # rank within the data-parallel group
    else:
        # Native torch.distributed: derive the global rank from the host index and
        # the per-host local rank, then initialize the process group.
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            args.rank = args.num_gpus * args.host_num + args.local_rank  # global rank across all hosts

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            "Initialized the distributed environment: '{}' backend on {} nodes. "
            "Current host rank is {}. Number of gpus: {}".format(
                args.backend, dist.get_world_size(), dist.get_rank(), args.num_gpus))
    
    # Scale the learning rate linearly with the total number of workers and split the
    # global batch size across hosts (args.world_size // args.num_gpus is the host count).
    args.lr = args.lr * float(args.world_size)
    args.batch_size //= args.world_size // args.num_gpus
    args.batch_size = max(args.batch_size, 1)

    return args
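
For context, a sketch of how a training entry point might call dist_setting(); the argument names mirror the fields read above, and the concrete values (two hosts with four GPUs each) are illustrative only:

import argparse

# Hypothetical args in the shape the SageMaker training toolkit would provide.
args = argparse.Namespace(
    data_parallel=False, model_parallel=False, apex=False,
    hosts=["algo-1", "algo-2"], current_host="algo-1",
    num_gpus=4, local_rank=0, backend="nccl",
    lr=0.1, batch_size=256,
)

# With neither data_parallel nor model_parallel set, the torch.distributed branch runs,
# so MASTER_ADDR/MASTER_PORT must be set and all 8 ranks must be launched.
args = dist_setting(args)
# Resulting values: world_size = 8, rank = 0, lr = 0.8, batch_size = 128.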