def parse_args()

in gossip_sgd.py [0:0]


def _pairs_to_schedule(flat_values):
    """Fold a flat ``[key0, val0, key1, val1, ...]`` iterable into a dict.

    A trailing key that has no matching value is ignored, and when a key
    appears more than once the last value wins.
    """
    it = iter(flat_values)
    # zipping an iterator with itself consumes the items two at a time
    return dict(zip(it, it))


def parse_args():
    """
    Parse the command line, derive the distributed-run settings, and
    initialize the torch.distributed process group.

    Values taken from the environment:
        rank         <-- $OMPI_COMM_WORLD_RANK (mpi backend) or $SLURM_PROCID
        world_size   <-- $OMPI_COMM_WORLD_SIZE, falling back to
                         $OMPI_UNIVERSE_SIZE (mpi backend), or $SLURM_NTASKS
        master_addr  <-- $HOSTNAME

    Environment variables written:
        MASTER_ADDR, MASTER_PORT (from args.master_port), and, depending on
        the backend / network interface type, GLOO_SOCKET_IFNAME or
        NCCL_SOCKET_IFNAME together with NCCL_IB_DISABLE.

    Other side effects:
        Calls ClusterManager.set_checkpoint_dir and dist.init_process_group
        (plus dist.barrier when a graph topology is selected).

    Returns:
        The parsed argument namespace, extended with rank, world_size,
        master_addr, out_fname, cpu_comm, comm_device, lr_schedule,
        ppi_schedule, graph and mixing. The string-valued flags are
        converted to booleans, and the raw `schedule` and
        `peers_per_itr_schedule` attributes are removed.

    Raises:
        KeyError: if a required environment variable is not set.
        AssertionError: if the argument combination is inconsistent.
        NotImplementedError: for an ethernet interface with a backend other
            than gloo or nccl.
    """
    args = parser.parse_args()
    ClusterManager.set_checkpoint_dir(args.checkpoint_dir)

    # rank and world_size need to be changed depending on the scheduler being
    # used to run the distributed jobs
    args.master_addr = os.environ['HOSTNAME']
    if args.backend == 'mpi':
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        # OMPI_COMM_WORLD_SIZE is the number of processes actually launched;
        # OMPI_UNIVERSE_SIZE counts the slots allocated to the job and can be
        # larger, which would make the topology reference missing peers.
        world_size = os.environ.get('OMPI_COMM_WORLD_SIZE')
        if world_size is None:
            world_size = os.environ['OMPI_UNIVERSE_SIZE']
        args.world_size = int(world_size)
    else:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])

    args.out_fname = (ClusterManager.CHECKPOINT_DIR
                      + args.tag
                      + 'out_r' + str(args.rank)
                      + '_n' + str(args.world_size)
                      + '.csv')

    # these flags arrive as the strings 'True' / 'False'; only the exact
    # string 'True' enables a flag
    for flag in ('resume', 'verbose', 'train_fast', 'nesterov',
                 'checkpoint_all', 'warmup', 'overlap', 'push_sum',
                 'all_reduce', 'overwrite_checkpoints'):
        setattr(args, flag, getattr(args, flag) == 'True')

    # communicate through CPU tensors only for plain gloo gossip (neither
    # push-sum nor all-reduce)
    args.cpu_comm = (args.backend == 'gloo'
                     and not args.push_sum
                     and not args.all_reduce)
    args.comm_device = torch.device('cpu' if args.cpu_comm else 'cuda')

    # parse learning-rate sched, given as a flat list of pairs
    schedule = args.schedule
    if schedule is None:
        schedule = [30, 0.1, 60, 0.1, 80, 0.1]
    args.lr_schedule = _pairs_to_schedule(schedule)
    del args.schedule

    # parse peers per itr sched (epoch, num_peers)
    ppi_pairs = args.peers_per_itr_schedule
    if ppi_pairs is None:
        ppi_pairs = [0, 1]
    args.ppi_schedule = _pairs_to_schedule(ppi_pairs)
    del args.peers_per_itr_schedule
    # must specify how many peers to communicate from the start of training
    assert 0 in args.ppi_schedule, \
        'peers_per_itr_schedule must contain an entry for epoch 0'

    if args.all_reduce:
        assert args.graph_type == -1, \
            'all_reduce requires graph_type == -1'

    if args.backend == 'gloo':
        assert args.network_interface_type == 'ethernet', \
            "the gloo backend requires network_interface_type == 'ethernet'"
        os.environ['GLOO_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type
        )
    elif args.network_interface_type == 'ethernet':
        if args.backend != 'nccl':
            raise NotImplementedError
        os.environ['NCCL_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type
        )
        os.environ['NCCL_IB_DISABLE'] = '1'

    # initialize torch distributed backend
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port
    dist.init_process_group(backend=args.backend,
                            world_size=args.world_size,
                            rank=args.rank)

    args.graph, args.mixing = None, None
    graph_class = GRAPH_TOPOLOGIES[args.graph_type]
    if graph_class:
        # dist.barrier is done here to ensure the NCCL communicator is created
        # here. This prevents an error which may be caused if the NCCL
        # communicator is created at a time gap of more than 5 minutes in
        # different processes
        dist.barrier()
        args.graph = graph_class(
            args.rank, args.world_size, peers_per_itr=args.ppi_schedule[0])

    mixing_class = MIXING_STRATEGIES[args.mixing_strategy]
    if mixing_class and args.graph:
        args.mixing = mixing_class(args.graph, args.comm_device)

    return args