in gossip/ad_psgd.py [0:0]
def _gossip_target(dist_config, gossip_enable_flag, train_write_flag,
                   gossip_read_flag, gossip_update_flag, gossip_lr,
                   gossip_lock, gossip_queue, tcp_interface_name):
    """Gossip-process entry point: mixes model parameters with peers.

    NOTE(review): the original one-liner said "push-sum", but the code
    below constructs a ``BilatPushPull`` gossiper and averages incoming
    messages into the local params (``(p + g) * 0.5``), i.e. bilateral
    push-pull averaging — confirm the intended algorithm name.

    Runs forever (no return); terminates only with the hosting process.

    Args:
        dist_config: dict of run configuration; keys read here include
            'lr', 'momentum', 'weight_decay', 'nesterov', 'backend',
            'network_interface_type', 'master_addr', 'master_port',
            'world_size', 'rank', 'verbose', 'graph_class',
            'mixing_class', 'num_peers', 'comm_device'. Also *written*:
            'graph', 'mixing', 'gossiper' are stored back into it.
        gossip_enable_flag: event — gossip loop blocks while cleared.
        train_write_flag: event — trainer sets it after writing fresh
            grads into ``gossip_grads``; cleared here once consumed.
        gossip_read_flag: event — set here to tell the trainer the
            shared params/grads are safe to overwrite again.
        gossip_update_flag: event — trainer sets it to request a
            learning-rate refresh from ``gossip_lr``.
        gossip_lr: shared value holding the current learning rate
            (read via ``.value``).
        gossip_lock: lock guarding all access to ``gossip_params``.
        gossip_queue: one-shot channel delivering the shared
            (params, grads) tensor lists at startup.
        tcp_interface_name: NIC name exported to NCCL/Gloo when the
            network interface type is 'ethernet'.

    NOTE(review): the flag/lock/queue objects are presumably
    ``multiprocessing`` primitives (Event/Lock/Value/Queue) — confirm
    against the caller; only their method surface is visible here.
    """
    # Gossip never needs autograd; everything below is raw tensor math.
    with torch.no_grad():
        # Blocks until the training process hands over the shared tensors.
        gossip_params, gossip_grads = gossip_queue.get()
        # prepare gossip process control objects
        gossip_optimizer = torch.optim.SGD(
            gossip_params,
            lr=dist_config['lr'],
            momentum=dist_config['momentum'],
            weight_decay=dist_config['weight_decay'],
            nesterov=dist_config['nesterov'])

        # Network-interface wiring: gloo is only supported over ethernet;
        # for ethernet, pin NCCL/Gloo to the given NIC via env vars.
        if dist_config['backend'] == 'gloo':
            assert dist_config['network_interface_type'] == 'ethernet'
        elif dist_config['network_interface_type'] == 'ethernet':
            if dist_config['backend'] == 'nccl':
                os.environ['NCCL_SOCKET_IFNAME'] = tcp_interface_name
                # Force TCP sockets; disable InfiniBand transport.
                os.environ['NCCL_IB_DISABLE'] = '1'
            elif dist_config['backend'] == 'gloo':
                os.environ['GLOO_SOCKET_IFNAME'] = tcp_interface_name
            else:
                raise NotImplementedError

        # initialize torch distributed backend
        os.environ['MASTER_ADDR'] = dist_config['master_addr']
        os.environ['MASTER_PORT'] = dist_config['master_port']
        dist.init_process_group(backend=dist_config['backend'],
                                world_size=dist_config['world_size'],
                                rank=dist_config['rank'])

        logger = make_logger(dist.get_rank(), dist_config['verbose'])
        logger.debug('init rcvd: gossip_params {}, gossip_grads {}'.format(
            gossip_params[0].norm(), gossip_grads[0].norm()))

        # init gossip instance
        graph_class = dist_config['graph_class']
        mixing_class = dist_config['mixing_class']
        # NOTE(review): if graph_class is falsy, `graph` (and possibly
        # `mixing`) is never bound and the BilatPushPull call below would
        # raise NameError — presumably callers always supply both; verify.
        if graph_class:
            # dist.barrier is done here to ensure the NCCL communicator is
            # created here. This prevents an error which may be caused if
            # the NCCL communicator is created at a time gap of more
            # than 5 minutes in different processes
            dist.barrier()
            graph = graph_class(
                dist_config['rank'], dist_config['world_size'],
                peers_per_itr=dist_config['num_peers'])
        if mixing_class and graph:
            mixing = mixing_class(graph, dist_config['comm_device'])

        # Gossiper operates on a single flattened view of all params.
        gossiper = BilatPushPull(flatten_tensors(gossip_params),
                                 graph=graph,
                                 mixing=mixing,
                                 logger=logger)

        # Expose the live managers back to the caller through dist_config.
        dist_config['graph'] = gossiper._graph_manager
        dist_config['mixing'] = gossiper._mixing_manager
        dist_config['gossiper'] = gossiper
        # Timing meters: local-model update cost vs. gossip-step cost.
        model_meter = Meter(ptag='Model', stateful=True, csv_format=False)
        gossip_meter = Meter(
            ptag='Gossip', stateful=True, csv_format=False)
        # Signal the trainer that initialization is done and the shared
        # buffers may be written.
        gossip_read_flag.set()

        # gossip loop
        while True:
            # we may be asked to hold off on gossip for some time
            gossip_enable_flag.wait()
            # we may be notified to update our learning rate
            if gossip_update_flag.is_set():
                for pg in gossip_optimizer.param_groups:
                    pg['lr'] = gossip_lr.value
                logger.debug('updated lr to {}'.format(gossip_lr.value))
                gossip_update_flag.clear()
            # train process is telling us it computed the new grads
            if train_write_flag.is_set():
                bt = time.time()
                with gossip_lock:
                    # Attach the trainer-provided grads to the matching
                    # params; gossip_grads holds grads only for params
                    # with requires_grad, in order (hence the counter).
                    i = 0
                    for p in gossip_params:
                        if p.requires_grad:
                            p.grad = gossip_grads[i]
                            i += 1
                    gossip_optimizer.step()
                    gossip_optimizer.zero_grad()
                # Flag order: consume the write notice, then tell the
                # trainer the shared buffers are free again.
                train_write_flag.clear()
                gossip_read_flag.set()
                model_meter.update(time.time() - bt)
                logger.debug(model_meter)
            try:
                # construct gossip tensor
                bt = time.time()
                with gossip_lock:
                    out_msg = flatten_tensors(gossip_params).to(
                        dist_config['comm_device'])
                # gossip step (communication happens outside the lock)
                in_msg, completed = gossiper.mix(out_msg)
                # update gossip params (local model): average the
                # received message into each param, p <- (p + g) / 2
                if completed:
                    with gossip_lock:
                        for p, g in zip(
                                gossip_params, unflatten_tensors(
                                    in_msg, gossip_params)):
                            p.data.add_(g.to(p.device)).mul_(0.5)
                gossip_meter.update(time.time() - bt)
                logger.debug(gossip_meter)
            except RuntimeError as e:
                # Best-effort recovery from a failed exchange (e.g. comm
                # error): log, drop any half-sent buffers, keep looping.
                logger.warning('received runtime error {}'.format(e))
                gossiper.clean_msg_buffers_()