# def mix()
#
# in fairscale/experimental/nn/data_parallel/gossip/gossiper.py [0:0]


    def mix(self, out_msg: torch.Tensor, ps_weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one gossip round: broadcast our mixed message to out-peers and
        gather the messages arriving from in-peers.

        Returns whatever ``parse_in_msg_buffer`` extracts from the filled
        in-message buffer (message tensor and push-sum weight).
        """
        # the outgoing message must already live on this gossiper's device
        assert out_msg.device.type == self.device.type
        if self.logger is not None:
            self.logger.debug("in/out -peers {}/{}".format(self.in_edges, self.out_edges))

        # lazily produce one pre-mixed (weighted) message per out-edge
        mixed_out_msgs = self.mix_out_msg_(out_msg, ps_weight)

        def exchange(send_tensor: torch.Tensor, recv_tensor: torch.Tensor, out_edge: Any, in_edge: Any) -> None:
            # Active ranks send first and then receive; passive ranks do the
            # reverse, so the two endpoints of each edge pair up without
            # deadlocking on the blocking broadcasts.
            if self.passive:
                dist.broadcast(tensor=recv_tensor, src=in_edge.src, group=in_edge.process_group)
                dist.broadcast(tensor=send_tensor, src=out_edge.src, group=out_edge.process_group)
            else:
                dist.broadcast(tensor=send_tensor, src=out_edge.src, group=out_edge.process_group)
                dist.broadcast(tensor=recv_tensor, src=in_edge.src, group=in_edge.process_group)

        if len(self.in_edges) == 1 and len(self.out_edges) == 1:
            # fast path: a single peer on each side lets us receive straight
            # into the in-message buffer, skipping zeroing/accumulation
            exchange(next(mixed_out_msgs), self.in_msg_buffer, self.out_edges[0], self.in_edges[0])
        else:
            # general path: receive each peer's message into a placeholder
            # and accumulate into the (freshly zeroed) in-message buffer
            self.in_msg_buffer.zero_()
            for out_edge, in_edge in zip(self.out_edges, self.in_edges):
                exchange(next(mixed_out_msgs), self.placeholder, out_edge, in_edge)
                self.in_msg_buffer.add_(self.placeholder)  # type: ignore

        self.refresh_peers_()
        self.clean_msg_buffers_()
        return self.parse_in_msg_buffer()