in gossip/distributed.py [0:0]
def _query_gossip_queue(self, non_blocking=False):
    """
    Check gossip-queue for push-sum residuals and update model

    :param non_blocking: when False (default), wait up to
        HEARTBEAT_TIMEOUT on ``self.gossip_flag`` before inspecting it;
        when True, only inspect the flag's current state.
    :return: ``True`` if the residuals were added to the model parameters
        and push-sum weight; ``False`` if no gossip is in progress or the
        gossip was reported as interrupted (``gossip_ps_weight[0] == -1``);
        ``None`` if gossip is disabled or the gossip flag is not set.
    :raises NameError: in blocking mode, if the gossip flag is not set
        before the timeout expires.
    """
    if not self.gossip_enable:
        return None

    self.logger.debug('querying gossip queue')

    # no gossip happening right now so just return
    if not self.gossiping:
        # only ranks that are a multiple of nprocs_per_node emit the
        # warning (presumably one process per node -- confirm)
        if self.process_rank % self.nprocs_per_node == 0:
            self.logger.warning('not gossiping right now')
        return False

    # blocking mode: wait for the gossip flag, and treat a timeout as fatal
    # NOTE(review): NameError is an unusual exception type for a timeout;
    # it is kept because callers may already catch it.
    if not non_blocking and \
            not self.gossip_flag.wait(timeout=HEARTBEAT_TIMEOUT):
        raise NameError('Gossip flag timeout')

    # query gossip thread; nothing to consume unless the flag is set
    if not self.gossip_flag.is_set():
        return None
    self.logger.debug('received gossip flag')

    # atomic gossip was interrupted so try again
    if self.gossip_ps_weight[0] == -1:
        # reset the flags *before* calling transfer_params, exactly as
        # the original ordering did
        self.gossip_flag.clear()
        self.params_mixed = True
        self.gossiping = False
        self.transfer_params(mix=False)
        return False

    self.lazy_ps_factor.copy_(self.gossip_ps_factor)

    # convert model-params to ps numerators b4 adding residuals
    self.ps_numerator()

    # add residuals (and apply the lazy-mixing factor when enabled)
    self.ps_weight += self.gossip_ps_weight
    if self.lazy_mixing:
        self.ps_weight *= self.lazy_ps_factor
    for p, r in zip(self.module.parameters(),
                    self.gossip_device_buffer):
        p.data.add_(r)
        if self.lazy_mixing:
            # match the factor's dtype to the parameter's dtype
            p.data.mul_(self.lazy_ps_factor.type(p.data.dtype))

    # update flags
    self.logger.debug('updated ps-weight {}'.format(self.ps_weight))
    self.logger.debug('updated model params')
    self.gossip_flag.clear()
    self.params_mixed = True
    self.gossiping = False
    return True