in src/nanotron/parallel/pipeline_parallel/state.py [0:0]
def run_communication(self, send_only_activation: bool = False):
    """Run at most one queued activation send and one queued activation receive.

    One entry is taken from the head of ``self.microbatches_activations_to_send``
    and one from ``self.microbatches_activations_to_recv`` (when present). A
    dequeued send is invoked for its side effect; the value returned by a
    dequeued receive is appended to ``self.activations_buffer``.

    When both a send and a receive are pending, there are four possible
    situations for this rank:
      - it receives from a higher rank and sends to a higher rank
      - it receives from a higher rank and sends to a lower rank
      - it receives from a lower rank and sends to a higher rank
      - it receives from a lower rank and sends to a lower rank

    Args:
        send_only_activation: Accepted for interface compatibility; this
            method body does not read it.

    Raises:
        ValueError: If neither a send nor a receive is queued.
    """
    # Despite handling "all" pending work conceptually, only the head of each
    # queue is consumed per call (zero or one item each).
    send_queue = self.microbatches_activations_to_send
    send_activation = send_queue.popleft() if len(send_queue) > 0 else None

    recv_queue = self.microbatches_activations_to_recv
    recv_activation = recv_queue.popleft() if len(recv_queue) > 0 else None

    def store_received():
        # Run the receive and keep its result for later consumption.
        self.activations_buffer.append(recv_activation())

    if send_activation is None and recv_activation is None:
        raise ValueError("Why the hell do we communicate when there's nothing to communicate?")

    if send_activation is None:
        store_received()
        return

    if recv_activation is None:
        send_activation()
        return

    # Both directions are pending, so an ordering has to be chosen. A proper
    # heuristic would need global knowledge of who takes part in the cycle,
    # which is not available here. The BIG assumption is that only ONE rank
    # both receives from a higher rank and sends to a higher rank: that
    # "lowest" rank sends first and then receives, while every other rank
    # receives first and sends afterwards.
    p2p = send_activation.p2p
    assert p2p == recv_activation.p2p
    is_lowest = (
        send_activation.to_rank > dist.get_rank(p2p.pg)
        and recv_activation.from_rank > dist.get_rank(p2p.pg)
    )

    ordered_steps = (send_activation, store_received) if is_lowest else (store_received, send_activation)
    for step in ordered_steps:
        step()