def pipeline_communicate()

in step8_pipeline_parallel_1f1b/pipeline_parallel.py


import torch
import torch.distributed as dist

import process_group_manager as pgm  # assumed project-local import; the source accesses it as `pgm.process_group_manager`


def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
    """
    Handles point-to-point communication between pipeline stages for forward and backward passes.
    
    Args:
        operation (str): Type of communication operation ('recv_forward', 'send_forward', 
                        'recv_backward', 'send_backward')
        device: Device on which to allocate receive buffers (e.g., torch.device('cuda'))
        dtype: Data type of the communicated tensors
        tensor: Tensor to transmit; required for send operations (default: None)
        shapes: Shape of the buffer to allocate; required for receive operations (default: None)
    
    Returns:
        torch.Tensor or None: Received tensor for receive operations, None for send operations
    """
    global STEP
    global VERBOSE
    
    if operation == 'recv_forward':
        # Skip if this is the first pipeline stage (nothing to receive)
        if pgm.process_group_manager.pp_is_first_stage: return None
        # Allocate a buffer for the incoming activations; requires_grad=True lets
        # autograd propagate gradients back through this stage's input later on
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_prev_rank
    
    elif operation == 'send_forward':
        # Skip if this is the last pipeline stage (nothing to send forward)
        if pgm.process_group_manager.pp_is_last_stage: return
        dest = pgm.process_group_manager.pp_next_rank
    
    elif operation == 'recv_backward':
        # Skip if this is the last pipeline stage (no later stage to receive gradients from)
        if pgm.process_group_manager.pp_is_last_stage: return None
        # Allocate a buffer for the incoming gradients
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_next_rank
    
    elif operation == 'send_backward':
        # Skip if this is the first pipeline stage (nothing to send backward)
        if pgm.process_group_manager.pp_is_first_stage: return
        dest = pgm.process_group_manager.pp_prev_rank

    else:
        raise ValueError(f"Unknown pipeline_communicate operation: {operation}")

    # Determine if this is a send operation and set peer rank
    is_send = operation.startswith('send')
    peer_rank = dest if is_send else src
    
    # Create P2P operation (send or receive)
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    
    if VERBOSE: 
        print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} "
              f"{pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | "
              f"STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    
    # Launch the P2P transfer and block until it completes
    reqs = dist.batch_isend_irecv([op])
    for req in reqs:
        req.wait()
    torch.cuda.synchronize()
    
    # STEP is a debug counter; it only advances when verbose tracing is on
    if VERBOSE: STEP += 1
    
    # Return received tensor for receive operations, None for send operations
    return tensor if not is_send else None
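
To see how the four operations compose, here is a minimal sketch of one micro-batch on an arbitrary pipeline stage. It is illustrative only: `model`, `batch`, the shape constants, and the placeholder loss are assumptions, not the repo's actual training loop.

import torch

# Illustrative values; in practice these come from the training config.
seq_len, micro_batch_size, hidden_dim = 128, 4, 512
shapes = (seq_len, micro_batch_size, hidden_dim)
device, dtype = torch.device("cuda"), torch.float32

def micro_batch_step(model, batch):
    # Forward: every stage but the first receives its input activations,
    # runs its local layers, and ships the result to the next stage.
    input_tensor = pipeline_communicate('recv_forward', device, dtype, shapes=shapes)
    output_tensor = model(batch if input_tensor is None else input_tensor)
    pipeline_communicate('send_forward', device, dtype, tensor=output_tensor)

    # Backward: every stage but the last receives dL/d(output), backpropagates
    # through its local layers, and ships dL/d(input) to the previous stage.
    output_grad = pipeline_communicate('recv_backward', device, dtype, shapes=shapes)
    if output_grad is None:                      # last stage: start from the loss
        loss = output_tensor.float().mean()      # placeholder loss
        loss.backward()
    else:
        output_tensor.backward(output_grad)
    # recv_forward allocated input_tensor with requires_grad=True, so its .grad
    # is populated by the backward pass above (it stays None on the first stage).
    input_grad = None if input_tensor is None else input_tensor.grad
    pipeline_communicate('send_backward', device, dtype, tensor=input_grad)

A full 1F1B schedule interleaves these calls across micro-batches (warmup forwards, a steady one-forward-one-backward phase, then cooldown backwards), but every transfer it performs reduces to one of the four operations above.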