in step8_pipeline_parallel_1f1b/pipeline_parallel.py
def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
    """
    Handles bidirectional communication between pipeline stages, allowing simultaneous
    send and receive operations.

    Args:
        operation (str): Type of bidirectional operation ('send_fwd_recv_bwd' or 'send_bwd_recv_fwd')
        send_tensor: Tensor to be sent
        recv_shapes: Shape specification for the tensor to be received
        device: Target device for tensor operations
        dtype: Data type for tensors

    Returns:
        torch.Tensor or None: Received tensor, or None if at a terminal pipeline stage
    """
    global STEP
    global VERBOSE
    # Determine if this is a forward operation
    is_fwd = (operation == 'send_fwd_recv_bwd')
    # Skip if at terminal pipeline stages
    if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or \
       (not is_fwd and pgm.process_group_manager.pp_is_first_stage):
        return None
    # Determine peer rank based on operation direction
    peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
    # Create empty tensor for receiving data
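    # The buffer is allocated with requires_grad=True so that, when it holds an incoming
    # activation, it can act as a leaf of this stage's autograd graph (its .grad is what
    # eventually gets sent back to the previous stage)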
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    # Set up simultaneous send and receive operations
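    # batch_isend_irecv posts both P2P ops together (grouped under the hood for NCCL),
    # so two neighboring ranks that send to and receive from each other at the same time
    # cannot deadlock on the ordering of blocking point-to-point calls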
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, send_tensor, peer_rank),
        dist.P2POp(dist.irecv, recv_tensor, peer_rank)
    ])
    if VERBOSE:
        print(f"{operation} | sending {'next' if is_fwd else 'prev'} "
              f"{pgm.process_group_manager.pp_rank} -> {peer_rank} | "
              f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> "
              f"{pgm.process_group_manager.pp_rank} | {STEP=} | "
              f"RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    # Wait for both operations to complete
    for req in reqs:
        req.wait()
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    return recv_tensor
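
The 1F1B steady state is where this helper earns its keep: right after a microbatch's forward pass, the stage pushes the new activation to its next neighbor and, in the same batched exchange, pulls back the gradient of an earlier microbatch's output. Below is a minimal, hypothetical call-site sketch; every name other than bidirectional_pipeline_communicate (model, microbatch, device, input_tensor_grad, and the surrounding schedule) is an illustrative placeholder, not code from this file.

# Hypothetical steady-state step of a 1F1B schedule (placeholder names).
output_tensor = model(microbatch)  # forward pass for the newest microbatch
# Ship the fresh activation forward and, in the same batched exchange, receive
# the gradient w.r.t. a previously sent output from the next stage.
output_tensor_grad = bidirectional_pipeline_communicate(
    operation='send_fwd_recv_bwd',
    send_tensor=output_tensor,
    recv_shapes=output_tensor.shape,
    device=device,
    dtype=output_tensor.dtype,
)
# After running backward for the oldest in-flight microbatch, the mirror call
# sends its input gradient backward while receiving the next activation.
input_tensor = bidirectional_pipeline_communicate(
    operation='send_bwd_recv_fwd',
    send_tensor=input_tensor_grad,
    recv_shapes=input_tensor_grad.shape,
    device=device,
    dtype=input_tensor_grad.dtype,
)

Note that on the last stage the 'send_fwd_recv_bwd' call simply returns None, and on the first stage 'send_bwd_recv_fwd' does the same, mirroring the early exit at the top of the function.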