def write_message()

in gala/gpu_gossip_buffer.py [0:0]


    def write_message(self, rank, model, rotate=False):
        """
        Write agent 'rank's 'model' to a local 'broadcast buffer' that will be
        read by the out-neighbours defined in 'self.topology'.

        :param rank: Agent's rank in multi-agent graph topology
        :param model: Agent's torch neural network model
        :param rotate: Whether to alternate peers in graph topology

        Agents should only write to their own broadcast buffer:
            i.e., ensure 'model' belongs to agent 'rank'
        WARNING: setting rotate=True with sync_freq > 1 is not currently supported
        """
        with torch.no_grad():

            # Get local broadcast buffer, read/write event lists, and lock
            msg_buffer = self.msg_buffer[rank]
            broadcast_buffer = msg_buffer[0]
            read_event_list = msg_buffer[1]
            write_event_list = msg_buffer[2]
            lock = msg_buffer[3]

            # Check if out-peers finished reading our last message
            out_peers, _ = self.topology[rank].get_peers()
            read_complete = all(
                read_event_list[peer].is_set() for peer in out_peers)

            # Peers have finished reading our last message: clear their read
            # events so the next round of reads can be detected
            if read_complete:
                for peer in out_peers:
                    read_event = read_event_list[peer]
                    read_event.wait()
                    read_event.clear()
            # Peers are still reading the last message: skip this write and
            # try again on a later call
            else:
                return

            # Update broadcast buffer with the new message
            # -- copy params and pre-scale by the mixing weight 1 / (num_peers + 1)
            num_peers = self.topology[rank].peers_per_itr
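            # Hold the lock so out-peers never read a partially written buffer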
            with lock:
                for bp, p in zip(broadcast_buffer.parameters(),
                                 model.parameters()):
                    bp.data.copy_(p)
                    bp.data.div_(num_peers + 1)
                # -- mark message as 'written'
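                # -- re-fetch out-peers (rotate=True advances to the next peer set)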
                out_peers, _ = self.topology[rank].get_peers(rotate)
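                # Ensure the parameter copies above have finished on the GPU
                # before signaling out-peers that the message is ready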
                torch.cuda.current_stream().synchronize()
                for peer in out_peers:
                    write_event_list[peer].set()
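
The division by 'num_peers + 1' pre-scales each outgoing message by a uniform
mixing weight, so a receiver can recover the average of all participating
models simply by summing its own (equally scaled) parameters with the incoming
messages. Below is a self-contained numerical sketch of that arithmetic, with
toy tensors standing in for model parameters; it assumes the read side sums
pre-scaled messages, which is not shown in this excerpt.

    import torch

    num_peers = 2                          # messages arriving per iteration
    own = torch.tensor([3.0, 6.0])         # this agent's params
    msgs = [torch.tensor([0.0, 3.0]),      # peers' params, as sent
            torch.tensor([6.0, 0.0])]

    # Writers pre-scale by 1 / (num_peers + 1), as in write_message() above
    scaled = [m / (num_peers + 1) for m in msgs]

    # The reader scales its own params the same way and sums everything,
    # recovering the uniform average over all num_peers + 1 models
    mixed = own / (num_peers + 1) + sum(scaled)
    assert torch.allclose(mixed, torch.tensor([3.0, 3.0]))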
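
The read/write event pairs implement a non-blocking handshake: a writer
publishes only after every out-peer has signalled that it finished reading the
previous message, and it announces the new message by setting the write
events; otherwise the write is skipped rather than blocking training. The
following is a minimal single-writer/single-reader sketch of the same protocol
using 'threading.Event' in place of the CUDA-side events and buffers used
here (the buffer lock is omitted for brevity).

    import threading

    buffer = {"msg": None}
    read_event = threading.Event()
    write_event = threading.Event()
    read_event.set()  # no message pending initially, so writing is allowed

    def write_message(msg):
        # Skip the write if the reader has not consumed the last message
        if not read_event.is_set():
            return False
        read_event.clear()
        buffer["msg"] = msg       # stands in for the param copy + pre-scale
        write_event.set()         # mark the message as 'written'
        return True

    def read_message():
        write_event.wait()        # block until a message is available
        write_event.clear()
        msg = buffer["msg"]
        read_event.set()          # let the writer publish the next message
        return msg

    reader = threading.Thread(target=lambda: print("read:", read_message()))
    reader.start()
    assert write_message([1.0, 2.0])  # succeeds: read_event starts set
    reader.join()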