def run_communication()

in src/nanotron/parallel/pipeline_parallel/state.py [0:0]


    def run_communication(self, send_only_activation: bool = False):
        """Perform the next pending activation send and/or receive.

        At most ONE entry is popped (from the left) from each of
        ``self.microbatches_activations_to_send`` and
        ``self.microbatches_activations_to_recv``. Callers that need to drain
        these queues must call this method repeatedly.

        Each popped entry is a callable:
          - calling a send entry performs the send (its return value is ignored);
          - calling a recv entry returns the received activation, which is
            appended to ``self.activations_buffer``.

        When both a send and a receive are pending, the peer ranks can be in one
        of four configurations:
          - receive from a higher rank, send to a higher rank
          - receive from a higher rank, send to a lower rank
          - receive from a lower rank, send to a higher rank
          - receive from a lower rank, send to a lower rank

        Args:
            send_only_activation: Not used by this implementation; kept so the
                signature stays compatible with existing callers.

        Raises:
            ValueError: If neither a send nor a receive is pending.
        """
        send_queue = self.microbatches_activations_to_send
        recv_queue = self.microbatches_activations_to_recv

        # Pop at most ONE pending entry of each kind (send first, then recv).
        send_activation = send_queue.popleft() if len(send_queue) > 0 else None
        recv_activation = recv_queue.popleft() if len(recv_queue) > 0 else None

        if send_activation is None and recv_activation is None:
            raise ValueError("Why the hell do we communicate when there's nothing to communicate?")

        # Only one direction pending: no ordering decision needed.
        if send_activation is None:
            self.activations_buffer.append(recv_activation())
            return
        if recv_activation is None:
            send_activation()
            return

        # Both a send and a receive are pending, so their relative order matters.
        # NOTE(review): presumably this is to avoid a cycle of ranks all blocking
        # on the same operation (e.g. everyone sending first) — confirm against
        # the p2p implementation.
        # Picking a correct order in general needs global information we don't
        # have here, so we rely on a BIG assumption: exactly ONE rank receives
        # from a higher rank AND sends to a higher rank. That rank sends first;
        # every other rank receives first and sends afterwards.
        p2p = send_activation.p2p
        assert p2p == recv_activation.p2p, "send and recv activations must share the same p2p object"
        current_rank = dist.get_rank(p2p.pg)
        sends_first = send_activation.to_rank > current_rank and recv_activation.from_rank > current_rank
        if sends_first:
            send_activation()
            self.activations_buffer.append(recv_activation())
        else:
            self.activations_buffer.append(recv_activation())
            send_activation()