in fairscale/experimental/nn/data_parallel/gossip/gossiper.py [0:0]
def mix(self, out_msg: torch.Tensor, ps_weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Exchange one round of gossip messages with the current peers.

    Sends ``out_msg`` (weighted by ``ps_weight`` via ``mix_out_msg_``) to
    every out-edge and receives each in-edge peer's message into
    ``self.in_msg_buffer``.  Active and passive ranks issue the paired
    send/recv collectives in opposite order so matched peers cannot
    deadlock on each other.

    Returns whatever ``parse_in_msg_buffer`` extracts from the received
    buffer (presumably the mixed message and push-sum weight — the helper
    is defined elsewhere in this class).
    """
    # outgoing message must already live on this gossiper's device type
    assert out_msg.device.type == self.device.type
    if self.logger is not None:
        self.logger.debug("in/out -peers {}/{}".format(self.in_edges, self.out_edges))

    # lazily yields one pre-weighted copy of out_msg per out-edge
    mixed = self.mix_out_msg_(out_msg, ps_weight)

    if len(self.in_edges) == 1 and len(self.out_edges) == 1:
        # fast path: exactly one peer each way — receive directly into
        # in_msg_buffer, skipping the zero + placeholder-accumulate dance
        out_edge = self.out_edges[0]
        in_edge = self.in_edges[0]
        msg = next(mixed)

        def send() -> None:
            dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group)

        def recv() -> None:
            dist.broadcast(tensor=self.in_msg_buffer, src=in_edge.src, group=in_edge.process_group)

        # passive ranks receive before sending; active ranks do the reverse
        first, second = (recv, send) if self.passive else (send, recv)
        first()
        second()
    else:
        # general path: sum each peer's message into in_msg_buffer through
        # a scratch placeholder tensor
        self.in_msg_buffer.zero_()
        for out_edge, in_edge in zip(self.out_edges, self.in_edges):
            msg = next(mixed)

            def send() -> None:
                dist.broadcast(tensor=msg, src=out_edge.src, group=out_edge.process_group)

            def recv() -> None:
                dist.broadcast(tensor=self.placeholder, src=in_edge.src, group=in_edge.process_group)

            first, second = (recv, send) if self.passive else (send, recv)
            first()
            second()
            self.in_msg_buffer.add_(self.placeholder)  # type: ignore

    self.refresh_peers_()
    self.clean_msg_buffers_()
    return self.parse_in_msg_buffer()