in gossip/gossiper.py [0:0]
def mix(self, out_msg, ps_weight, residual=False):
    """Run one gossip step: send mixed out-messages and receive from in-peers.

    ``self.mix_out_msg_(out_msg, ps_weight, residual)`` yields the outgoing
    messages. When ``residual`` is false, the first yielded item is copied
    into ``self.in_msg_buffer`` (presumably this node's own weighted share —
    confirm against ``mix_out_msg_``), and each subsequent item is sent along
    one out-edge, paired in order with ``self.out_edges``. When ``residual``
    is true, every yielded item goes to an out-edge and the in-message buffer
    only accumulates what is received.

    Every transfer is a ``dist.broadcast`` over the edge's ``process_group``.
    Active nodes send first and then receive; passive nodes receive first and
    then send.

    Args:
        out_msg: message to gossip; its ``device.type`` must equal
            ``self.device.type``.
        ps_weight: forwarded unchanged to ``self.mix_out_msg_``.
        residual: whether to run in residual mode (see above).

    Returns:
        The result of ``self.parse_in_msg_buffer(residual)``, evaluated after
        ``self.refresh_peers_()`` and ``self.clean_msg_buffers_()``.

    Raises:
        AssertionError: if ``out_msg`` and ``self.device`` have different
            device types (only when assertions are enabled).
    """
    # out_msg must be on the correct device
    assert out_msg.device.type == self.device.type, (
        'out_msg is on device type {!r}, expected {!r}'.format(
            out_msg.device.type, self.device.type))
    if self.logger is not None:
        # Lazy %-style args: with a stdlib logging.Logger the message is only
        # rendered when DEBUG is enabled, instead of on every gossip step.
        self.logger.debug('in/out -peers %s/%s',
                          self.in_edges, self.out_edges)

    def send_recv(msg, recv_buf, out_edge, in_edge):
        # Send `msg` on `out_edge` and receive into `recv_buf` from `in_edge`.
        # Active nodes send then receive; passive nodes receive then send.
        # Keep this order exactly: broadcasts are collectives, and peers
        # presumably rely on it to pair up without blocking each other.
        if not self.passive:
            dist.broadcast(tensor=msg, src=out_edge.src,
                           group=out_edge.process_group)
            dist.broadcast(tensor=recv_buf, src=in_edge.src,
                           group=in_edge.process_group)
        else:
            dist.broadcast(tensor=recv_buf, src=in_edge.src,
                           group=in_edge.process_group)
            dist.broadcast(tensor=msg, src=out_edge.src,
                           group=out_edge.process_group)

    # prepare messages for gossip
    mixed_out_msgs = self.mix_out_msg_(out_msg, ps_weight, residual)

    if len(self.in_edges) == 1 and len(self.out_edges) == 1 and residual:
        # Fast path: one peer in residual mode needs no buffer preparation,
        # so receive straight into the in-message buffer.
        out_edge, in_edge = self.out_edges[0], self.in_edges[0]
        msg = next(mixed_out_msgs)
        send_recv(msg, self.in_msg_buffer, out_edge, in_edge)
    else:
        # prepare in-msg buffer
        if not residual:
            self.in_msg_buffer.copy_(next(mixed_out_msgs))
        else:
            self.in_msg_buffer.zero_()
        # Receive each peer's message into the placeholder and accumulate.
        # NOTE(review): zip() stops at the shorter edge list; this assumes
        # in/out edge counts match — confirm against the graph manager.
        for out_edge, in_edge in zip(self.out_edges, self.in_edges):
            msg = next(mixed_out_msgs)
            send_recv(msg, self.placeholder, out_edge, in_edge)
            self.in_msg_buffer.add_(self.placeholder)

    self.refresh_peers_()
    self.clean_msg_buffers_()
    return self.parse_in_msg_buffer(residual)