def aggregate_message()

in gala/gpu_gossip_buffer.py [0:0]


    def aggregate_message(self, rank, model):
        """
        Average in-neighbours' messages with agent `rank`'s local model.

        Averages the parameters broadcast by all in-neighbours (as defined
        by `self.topology`) together with `model`'s own parameters, writing
        the result back into `model` in place.

        Agents should only aggregate messages into their own model:
            i.e., ensure `model` belongs to agent `rank`.

        Args:
            rank: integer id of the aggregating agent.
            model: the agent's local model (exposes `.parameters()`);
                updated in place.

        Returns:
            None. Returns early, without touching `model`, when some
            in-peer has not finished writing and the message staleness is
            still below `self.sync_freq`.
        """
        with torch.no_grad():

            # Check if in-peers finished writing messages to broadcast
            # buffers. msg_buffer[peer][2] maps rank -> 'write done' event.
            _, in_peers = self.topology[rank].get_peers()
            write_complete = all(
                self.msg_buffer[peer][2][rank].is_set() for peer in in_peers)

            # Check if any messages are excessively stale
            stale_assert = self.sync_list[rank] >= self.sync_freq

            # Not done writing, but staleness is still tolerable: skip this
            # round, bump the staleness counter, and try again next call.
            if not (write_complete or stale_assert):
                self.sync_list[rank] += 1
                print('%s: staleness %s' % (rank, self.sync_list[rank]))
                return

            # Peers done writing (or message too stale to keep skipping):
            # block until each write event fires, then clear it for reuse.
            for peer in in_peers:
                write_event = self.msg_buffer[peer][2][rank]
                write_event.wait()
                write_event.clear()
            # Reset the staleness counter once, after consuming all events
            # (the original reset it redundantly on every loop iteration).
            self.sync_list[rank] = 0

            # Lazy-mixing of local params: pre-scale so that adding the
            # peers' messages below yields the average over
            # (num_peers + 1) models.
            num_peers = self.topology[rank].peers_per_itr
            for p in model.parameters():
                p.data.div_(num_peers + 1)

            # Aggregate received messages.
            # msg_buffer[peer] layout (from usage here): [0] message model,
            # [1] rank -> 'read done' event, [2] rank -> 'write done'
            # event, [3] lock guarding the message.
            for peer in in_peers:
                peer_buffer = self.msg_buffer[peer]
                lock = peer_buffer[3]
                with lock:
                    # Read message and accumulate it into local params.
                    # NOTE(review): assumes the sender pre-scaled its
                    # message consistently with the div_ above — confirm
                    # against the write path.
                    peer_msg = peer_buffer[0]
                    for p, bp in zip(model.parameters(),
                                     peer_msg.parameters()):
                        p.data.add_(bp.to(p.device, non_blocking=True))
                    # Ensure the non-blocking device copies complete before
                    # releasing the lock and flagging the read.
                    torch.cuda.current_stream().synchronize()
                    # Mark message as 'read' so the peer may overwrite it
                    peer_buffer[1][rank].set()