def parse_args()

in gossip_sgd_adpsgd.py [0:0]


def parse_args():
    """
    Parse command-line arguments, derive run-time settings, and initialize
    the torch.distributed process group.

    Set env-vars and global args
        rank: <-- $OMPI_COMM_WORLD_RANK (mpi backend) or $SLURM_PROCID
        world_size: <-- $OMPI_UNIVERSE_SIZE (mpi backend) or $SLURM_NTASKS
        device_id: <-- $OMPI_COMM_WORLD_LOCAL_RANK (mpi) or $SLURM_LOCALID
        Master address <-- $HOSTNAME of the calling process
        Master port <-- args.master_port + 1

    Side effects:
        * sets the checkpoint directory on ClusterManager
        * rank 0 removes a pre-existing file at args.shared_fpath, and every
          process blocks until that file no longer exists
        * exports MASTER_ADDR / MASTER_PORT and, depending on the backend,
          GLOO_SOCKET_IFNAME or NCCL_SOCKET_IFNAME / NCCL_IB_DISABLE
        * calls dist.init_process_group

    Returns:
        The argparse namespace, extended with the derived attributes
        (rank, world_size, device_id, out_fname, lr_schedule, ppi_schedule,
        graph_class, mixing_class, ...). The raw ``schedule`` and
        ``peers_per_itr_schedule`` attributes are deleted from it.

    Raises:
        KeyError: a required environment variable is not set.
        AssertionError: the peers-per-itr schedule has no entry for epoch 0,
            or the gloo backend is used with a non-ethernet interface.
        NotImplementedError: an ethernet interface is requested with a
            backend other than gloo or nccl.
        Exception: the selected graph or mixing class resolves to None.
    """
    import time  # local import; only needed for the polling wait below

    args = parser.parse_args()
    ClusterManager.set_checkpoint_dir(args.checkpoint_dir)

    args.master_addr = os.environ['HOSTNAME']
    if args.backend == 'mpi':
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        # NOTE(review): OMPI_UNIVERSE_SIZE is used as the world size here;
        # confirm it always equals the number of launched ranks.
        args.world_size = int(os.environ['OMPI_UNIVERSE_SIZE'])
        args.device_id = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    else:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])
        args.device_id = int(os.environ['SLURM_LOCALID'])

    # per-rank output file: <checkpoint dir><tag>out_r<rank>_n<world>.csv
    args.out_fname = (ClusterManager.CHECKPOINT_DIR
                      + args.tag
                      + 'out_r' + str(args.rank)
                      + '_n' + str(args.world_size)
                      + '.csv')

    # these flags arrive as strings; only the exact string 'True' enables them
    string_flags = (
        'resume', 'verbose', 'train_fast', 'nesterov', 'checkpoint_all',
        'warmup', 'overlap', 'push_sum', 'all_reduce', 'bilat',
    )
    for flag in string_flags:
        setattr(args, flag, getattr(args, flag) == 'True')

    # the gloo backend communicates through CPU tensors
    args.cpu_comm = args.backend == 'gloo'
    args.comm_device = torch.device('cpu' if args.cpu_comm else 'cuda')

    args.global_epoch = None
    args.global_itr = None

    # rank 0 clears a leftover shared file; everyone waits until it is gone
    if args.rank == 0 and os.path.isfile(args.shared_fpath):
        os.remove(args.shared_fpath)
    while os.path.isfile(args.shared_fpath):
        # sleep between polls instead of spinning on the filesystem
        time.sleep(0.1)

    # parse lr sched, given as a flat list [epoch, value, epoch, value, ...];
    # a trailing unpaired epoch is ignored
    if args.schedule is None:
        args.schedule = [30, 0.1, 60, 0.1, 80, 0.1]
    flat_lr = list(args.schedule)
    args.lr_schedule = dict(zip(flat_lr[0::2], flat_lr[1::2]))
    del args.schedule

    # parse peers per itr sched (epoch, num_peers)
    if args.peers_per_itr_schedule is None:
        args.peers_per_itr_schedule = [0, 1]
    flat_ppi = list(args.peers_per_itr_schedule)
    args.ppi_schedule = dict(zip(flat_ppi[0::2], flat_ppi[1::2]))
    del args.peers_per_itr_schedule
    # must specify how many peers to communicate from the start of training
    assert 0 in args.ppi_schedule, \
        'peers_per_itr_schedule must contain an entry for epoch 0'

    if args.backend == 'gloo':
        assert args.network_interface_type == 'ethernet', \
            'the gloo backend requires network_interface_type ethernet'
        os.environ['GLOO_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type
        )
    elif args.network_interface_type == 'ethernet':
        if args.backend != 'nccl':
            raise NotImplementedError
        os.environ['NCCL_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type
        )
        os.environ['NCCL_IB_DISABLE'] = '1'

    # initialize torch distributed backend
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = str(int(args.master_port) + 1)
    dist.init_process_group(backend=args.backend,
                            world_size=args.world_size,
                            rank=args.rank)

    args.graph_class = GRAPH_TOPOLOGIES[args.graph_type]
    args.mixing_class = MIXING_STRATEGIES[args.mixing_strategy]

    if args.graph_class is None:
        raise Exception('Incorrect arguments for graph_type')
    if args.mixing_class is None:
        raise Exception('Incorrect arguments for mixing_strategy')

    return args