in gossip_sgd_adpsgd.py [0:0]
def _flat_pairs_to_dict(flat):
    """Convert a flat ``[key0, val0, key1, val1, ...]`` sequence to a dict.

    A trailing key with no value (odd-length input) is dropped, and a
    repeated key keeps its last value.
    """
    it = iter(flat)
    # zip pulls alternately from the same iterator -> (key, value) pairs
    return dict(zip(it, it))


def _lookup_or_none(table, key):
    """Return ``table[key]``, or None when ``key`` is not present."""
    try:
        return table[key]
    except (KeyError, IndexError):
        return None


def parse_args():
    """Parse CLI args, derive run settings, and init torch.distributed.

    Rank / world size / local device id come from the launcher environment:
      * backend == 'mpi': $OMPI_COMM_WORLD_RANK, $OMPI_COMM_WORLD_SIZE,
        $OMPI_COMM_WORLD_LOCAL_RANK
      * otherwise (SLURM): $SLURM_PROCID, $SLURM_NTASKS, $SLURM_LOCALID
    Master address <-- $HOSTNAME of this process.
        NOTE(review): this was documented as "$SLURM_NODENAME of rank 0";
        the code reads the local HOSTNAME, so presumably the launcher sets
        it to the master host -- confirm.
    Master port <-- args.master_port + 1.

    String-valued boolean flags ('True' vs anything else) are converted to
    bool. The flat ``[epoch, value, ...]`` schedules are converted to dicts.
    Network-interface env vars are set for the chosen backend before
    ``dist.init_process_group`` is called.

    Returns:
        The populated argparse namespace.

    Raises:
        ValueError: peers-per-itr schedule lacks epoch 0, or gloo is used
            with a non-ethernet interface.
        NotImplementedError: a backend other than nccl/gloo on ethernet.
        Exception: unknown graph_type or mixing_strategy.
    """
    import time  # local import: only used for the polling sleep below

    args = parser.parse_args()
    ClusterManager.set_checkpoint_dir(args.checkpoint_dir)
    args.master_addr = os.environ['HOSTNAME']

    if args.backend == 'mpi':
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        # OMPI_COMM_WORLD_SIZE is the number of launched processes;
        # OMPI_UNIVERSE_SIZE is the number of allocated slots and may be
        # larger, which would give a wrong world size.
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.device_id = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    else:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.world_size = int(os.environ['SLURM_NTASKS'])
        args.device_id = int(os.environ['SLURM_LOCALID'])

    args.out_fname = (
        f'{ClusterManager.CHECKPOINT_DIR}{args.tag}'
        f'out_r{args.rank}_n{args.world_size}.csv'
    )

    # These CLI options arrive as strings; only the exact string 'True' is true.
    for flag in ('resume', 'verbose', 'train_fast', 'nesterov',
                 'checkpoint_all', 'warmup', 'overlap', 'push_sum',
                 'all_reduce', 'bilat'):
        setattr(args, flag, getattr(args, flag) == 'True')

    # gloo communicates through CPU tensors; other backends use CUDA.
    args.cpu_comm = args.backend == 'gloo'
    args.comm_device = (torch.device('cpu') if args.cpu_comm
                        else torch.device('cuda'))
    args.global_epoch = None
    args.global_itr = None

    # Rank 0 deletes any existing shared file; every rank then waits until it
    # is gone (presumably a leftover from a previous run -- confirm). Poll
    # with a short sleep instead of spinning a core.
    if args.rank == 0 and os.path.isfile(args.shared_fpath):
        os.remove(args.shared_fpath)
    while os.path.isfile(args.shared_fpath):
        time.sleep(0.1)

    # lr schedule: flat [epoch, factor, epoch, factor, ...] -> {epoch: factor}
    if args.schedule is None:
        args.schedule = [30, 0.1, 60, 0.1, 80, 0.1]
    args.lr_schedule = _flat_pairs_to_dict(args.schedule)
    del args.schedule

    # peers-per-itr schedule: flat [epoch, num_peers, ...] -> {epoch: num_peers}
    if args.peers_per_itr_schedule is None:
        args.peers_per_itr_schedule = [0, 1]
    args.ppi_schedule = _flat_pairs_to_dict(args.peers_per_itr_schedule)
    del args.peers_per_itr_schedule
    # Must specify how many peers to communicate with from the start of training.
    if 0 not in args.ppi_schedule:
        raise ValueError(
            'peers_per_itr_schedule must contain an entry for epoch 0')

    if args.backend == 'gloo':
        if args.network_interface_type != 'ethernet':
            raise ValueError(
                "gloo backend requires network_interface_type='ethernet'")
        os.environ['GLOO_SOCKET_IFNAME'] = get_tcp_interface_name(
            network_interface_type=args.network_interface_type
        )
    elif args.network_interface_type == 'ethernet':
        if args.backend == 'nccl':
            os.environ['NCCL_SOCKET_IFNAME'] = get_tcp_interface_name(
                network_interface_type=args.network_interface_type
            )
            # Force NCCL onto the TCP socket interface instead of InfiniBand.
            os.environ['NCCL_IB_DISABLE'] = '1'
        else:
            raise NotImplementedError

    # Validate topology / mixing choices before the blocking rendezvous. A
    # missing key and a key mapped to None are both reported the same way.
    args.graph_class = _lookup_or_none(GRAPH_TOPOLOGIES, args.graph_type)
    args.mixing_class = _lookup_or_none(MIXING_STRATEGIES, args.mixing_strategy)
    if args.graph_class is None:
        raise Exception('Incorrect arguments for graph_type')
    if args.mixing_class is None:
        raise Exception('Incorrect arguments for mixing_strategy')

    # initialize torch distributed backend
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = str(int(args.master_port) + 1)
    dist.init_process_group(backend=args.backend,
                            world_size=args.world_size,
                            rank=args.rank)
    return args