in gossip/distributed.py [0:0]
def _gossip_target(dist_config, gossip_flag, train_flag, gossip_lock,
                   gossip_params, gossip_device_buffer,
                   gossip_ps_weight, gossip_ps_factor, gossip_stream):
    """ Gossip thread, which performs push-sum on model params """
    logger = make_logger(dist_config['rank'], dist_config['verbose'])

    # group params by dtype so each group can be flattened and gossiped
    # as a single contiguous message
    gossip_params_by_dtype = group_by_dtype(gossip_params)
    gossip_device_buffer_by_dtype = group_by_dtype(gossip_device_buffer)
    gossipers = {}
    # init one gossiper (push-sum or push-pull) per dtype group
    gossiper_class = PushSum if dist_config['push_sum'] else PushPull
    for dtype in gossip_params_by_dtype:
        gossipers[dtype] = gossiper_class(
            flatten_tensors(gossip_params_by_dtype[dtype]),
            device=dist_config['comm_device'],
            graph=dist_config['graph'],
            mixing=dist_config['mixing'],
            rank=dist_config['process_rank'],
            world_size=dist_config['world_size'],
            logger=logger)
    dist_config['gossipers'] = gossipers

    # publish the initial mixing weight and signal that init is complete
    gossip_ps_factor.data.copy_(
        gossipers[list(gossipers)[0]].mixing_weights['lo'])
    gossip_flag.set()
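
    # Expected handshake with the main thread: the main thread stages fresh
    # params in gossip_params and sets train_flag; this thread gossips into
    # gossip_device_buffer, then clears train_flag and sets gossip_flag so
    # the main thread knows the buffer is safe to read.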
    # gossip loop
    while True:
        train_flag.wait()
        logger.debug('received train-flag')
        try:
            with torch.cuda.stream(gossip_stream):
                for dtype in gossip_params_by_dtype:
                    ps_weight, ps_factor = GossipDataParallel._gossip_into_receive_buffer(
                        gossip_params_by_dtype[dtype], gossipers[dtype],
                        gossip_device_buffer_by_dtype[dtype],
                        gossip_ps_weight, gossip_lock, dist_config)
                gossip_ps_weight.copy_(ps_weight)
                gossip_ps_factor.copy_(ps_factor)
        except RuntimeError as e:
            logger.warning('received runtime error {}'.format(e))
            # drop any partially sent/received messages and flag the failed
            # iteration with a sentinel push-sum weight
            for gossiper in gossipers.values():
                gossiper.clean_msg_buffers_()
            gossip_ps_weight.fill_(-1)
        finally:
            # Make sure all queued operations are complete
            gossip_stream.synchronize()
            # give main thread go-ahead to read our gossip buffer
            train_flag.clear()
            gossip_flag.set()
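

# Illustrative sketch (not part of this module's API): how a caller such as
# GossipDataParallel might spawn the gossip thread and drive the flag
# handshake above. Buffer and dist_config construction is elided and assumed.
#
#   import threading
#
#   gossip_flag, train_flag = threading.Event(), threading.Event()
#   gossip_lock = threading.Lock()
#   gossip_stream = torch.cuda.Stream()
#   thread = threading.Thread(
#       target=_gossip_target,
#       args=(dist_config, gossip_flag, train_flag, gossip_lock,
#             gossip_params, gossip_device_buffer,
#             gossip_ps_weight, gossip_ps_factor, gossip_stream),
#       daemon=True)
#   thread.start()
#   gossip_flag.wait()   # block until the gossiper is initialized
#   gossip_flag.clear()
#
#   # per gossip iteration:
#   gossip_flag.clear()
#   train_flag.set()     # wake the gossip thread
#   gossip_flag.wait()   # later: receive buffer is now safe to read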