in gossip_sgd.py [0:0]
def parse_args():
"""
Set env-vars and global args
rank: <-- $SLRUM_PROCID
world_size<-- $SLURM_NTASKS
Master address <-- $SLRUM_NODENAME of rank 0 process (or HOSTNAME)
Master port <-- any free port (doesn't really matter)
"""
args = parser.parse_args()
ClusterManager.set_checkpoint_dir(args.checkpoint_dir)
# rank and world_size are read from scheduler-specific environment variables;
# adapt these lookups if a different scheduler launches the distributed jobs
args.master_addr = os.environ['HOSTNAME']
if args.backend == 'mpi':
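# Open MPI's mpirun exports these variables for every launched process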
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
args.world_size = int(os.environ['OMPI_UNIVERSE_SIZE'])
else:
args.rank = int(os.environ['SLURM_PROCID'])
args.world_size = int(os.environ['SLURM_NTASKS'])
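# per-rank CSV log named '<checkpoint_dir><tag>out_r<rank>_n<world_size>.csv'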
args.out_fname = ClusterManager.CHECKPOINT_DIR \
+ args.tag \
+ 'out_r' + str(args.rank) \
+ '_n' + str(args.world_size) \
+ '.csv'
args.resume = (args.resume == 'True')
args.verbose = (args.verbose == 'True')
args.train_fast = (args.train_fast == 'True')
args.nesterov = (args.nesterov == 'True')
args.checkpoint_all = (args.checkpoint_all == 'True')
args.warmup = (args.warmup == 'True')
args.overlap = (args.overlap == 'True')
args.push_sum = (args.push_sum == 'True')
args.all_reduce = (args.all_reduce == 'True')
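# plain gossip over the gloo backend exchanges CPU tensors; push-sum,
# all-reduce, and the other backends communicate from GPU (CUDA) memory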
args.cpu_comm = (args.backend == 'gloo' and not args.push_sum
and not args.all_reduce)
args.comm_device = torch.device('cpu') if args.cpu_comm else torch.device('cuda')
args.overwrite_checkpoints = (args.overwrite_checkpoints == 'True')
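# parse lr schedule as flat (epoch, factor) pairs; with the default below the
# learning rate is scaled by 0.1 at epochs 30, 60, and 80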
args.lr_schedule = {}
if args.schedule is None:
args.schedule = [30, 0.1, 60, 0.1, 80, 0.1]
for epoch, factor in zip(args.schedule[0::2], args.schedule[1::2]):
args.lr_schedule[epoch] = factor
del args.schedule
# parse peers-per-iteration schedule as flat (epoch, num_peers) pairs
args.ppi_schedule = {}
if args.peers_per_itr_schedule is None:
args.peers_per_itr_schedule = [0, 1]
for epoch, num_peers in zip(args.peers_per_itr_schedule[0::2],
args.peers_per_itr_schedule[1::2]):
args.ppi_schedule[epoch] = num_peers
del args.peers_per_itr_schedule
# must specify how many peers to communicate with from the start of training
assert 0 in args.ppi_schedule
if args.all_reduce:
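# all-reduce training runs without a gossip topology; graph_type == -1
# presumably maps to 'no graph' in GRAPH_TOPOLOGIES below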
assert args.graph_type == -1
if args.backend == 'gloo':
assert args.network_interface_type == 'ethernet'
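# pin gloo's TCP sockets to the detected ethernet interface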
os.environ['GLOO_SOCKET_IFNAME'] = get_tcp_interface_name(
network_interface_type=args.network_interface_type
)
elif args.network_interface_type == 'ethernet':
if args.backend == 'nccl':
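# pin NCCL to the detected ethernet interface and disable its InfiniBand transport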
os.environ['NCCL_SOCKET_IFNAME'] = get_tcp_interface_name(
network_interface_type=args.network_interface_type
)
os.environ['NCCL_IB_DISABLE'] = '1'
else:
raise NotImplementedError
# initialize torch distributed backend
os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_PORT'] = args.master_port
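# the default 'env://' init method reads MASTER_ADDR and MASTER_PORT set above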
dist.init_process_group(backend=args.backend,
world_size=args.world_size,
rank=args.rank)
args.graph, args.mixing = None, None
graph_class = GRAPH_TOPOLOGIES[args.graph_type]
if graph_class:
# dist.barrier() forces the NCCL communicator to be created at this point.
# This avoids errors that can occur when different processes create the
# communicator more than 5 minutes apart
dist.barrier()
args.graph = graph_class(
args.rank, args.world_size, peers_per_itr=args.ppi_schedule[0])
mixing_class = MIXING_STRATEGIES[args.mixing_strategy]
if mixing_class and args.graph:
args.mixing = mixing_class(args.graph, args.comm_device)
return args
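# Example SLURM launch (hypothetical; exact CLI flag names depend on the
# argparse definitions elsewhere in this file):
#   srun --ntasks=8 python gossip_sgd.py --backend nccl \
#       --master_port 40100 --tag my_run_ --checkpoint_dir /checkpoints/
# Each task then derives its rank and the world size from SLURM_PROCID and
# SLURM_NTASKS, as shown above.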