def bidirectional_pipeline_communicate()

in step8_pipeline_parallel_1f1b/pipeline_parallel.py [0:0]


# Assumed module-level context (names match the usage below): torch,
# torch.distributed, and the process group manager; the STEP/VERBOSE
# globals are defined at the top of the file.
import torch
import torch.distributed as dist
import process_group_manager as pgm


def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
    """
    Handles bidirectional communication between pipeline stages, allowing simultaneous 
    send and receive operations.
    
    Args:
        operation (str): Type of bidirectional operation ('send_fwd_recv_bwd' or 'send_bwd_recv_fwd')
        send_tensor: Tensor to be sent
        recv_shapes: Shape specifications for the tensor to be received
        device: Target device for tensor operations
        dtype: Data type for tensors
    
    Returns:
        torch.Tensor or None: Received tensor, or None if this rank is the terminal
        stage for the given direction (last stage for 'send_fwd_recv_bwd', first
        stage for 'send_bwd_recv_fwd') and the exchange is skipped
    """
    global STEP
    global VERBOSE
    
    # True when this call sends activations forward (and receives gradients back)
    is_fwd = (operation == 'send_fwd_recv_bwd')
    
    # Skip at the terminal stage for this direction: the last stage has no next
    # rank to send activations to, and the first stage has no previous rank
    if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or \
       (not is_fwd and pgm.process_group_manager.pp_is_first_stage): 
        return None
    
    # Determine peer rank based on operation direction
    peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
    
    # Create empty tensor for receiving data
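    # requires_grad=True so a received activation joins this stage's autograd
    # graph, letting backward produce the input gradient to send upstream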
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    
    # Post the send and receive as one batched group: batching lets the backend
    # match this pair with the peer's mirrored recv/send and avoids the ordering
    # deadlock that two blocking point-to-point calls could cause
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, send_tensor, peer_rank),
        dist.P2POp(dist.irecv, recv_tensor, peer_rank)
    ])
    
    if VERBOSE:
        print(f"{operation} | sending to {'next' if is_fwd else 'prev'} "
              f"{pgm.process_group_manager.pp_rank} -> {peer_rank} | "
              f"receiving from {'next' if is_fwd else 'prev'} {peer_rank} -> "
              f"{pgm.process_group_manager.pp_rank} | STEP:{STEP} | "
              f"RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    
    # Wait for both operations to complete
    for req in reqs:
        req.wait()
    torch.cuda.synchronize()
    
    if VERBOSE: STEP += 1
    
    return recv_tensor
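

A minimal sketch of how this function is typically driven from the 1F1B steady state,
where each forward send is paired with a gradient receive and vice versa. The names
model, input_tensor, activation_shape, backward_step, and device below are illustrative
assumptions, not identifiers from this file:

# Steady-state 1F1B step on an interior pipeline stage (illustrative sketch).
output_tensor = model(input_tensor)

# Ship activations to the next stage while pulling the gradient of our
# output back from it, in one batched exchange.
output_grad = bidirectional_pipeline_communicate(
    operation='send_fwd_recv_bwd',
    send_tensor=output_tensor,
    recv_shapes=activation_shape,  # e.g. (seq_len, micro_batch_size, hidden_dim)
    device=device,
    dtype=torch.float32,
)

# backward_step is an assumed helper that runs this stage's backward pass
# and returns the gradient with respect to the stage's input.
input_grad = backward_step(input_tensor, output_tensor, output_grad)

# Send our input gradient upstream while receiving the next micro-batch's
# activations from the previous stage.
input_tensor = bidirectional_pipeline_communicate(
    operation='send_bwd_recv_fwd',
    send_tensor=input_grad,
    recv_shapes=activation_shape,
    device=device,
    dtype=torch.float32,
)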