in step7_pipeline_parallel_afab/pipeline_parallel.py
import os

import torch
import torch.distributed as dist

import process_group_manager as pgm  # assumed to match the import used by the rest of this file

# Debug-logging state read by pipeline_communicate (assumed defaults).
STEP = 0
VERBOSE = os.environ.get("VERBOSE", "0") == "1"

def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
    """
    Handles point-to-point communication between pipeline stages for the forward and backward passes.

    Args:
        operation (str): Type of communication operation ('recv_forward', 'send_forward',
            'recv_backward', 'send_backward')
        device: Target device for tensor operations (e.g., CPU, GPU)
        dtype: Data type of the tensors being communicated
        tensor: Tensor to transmit, for send operations (default: None)
        shapes: Shape of the tensor to allocate, for receive operations (default: None)

    Returns:
        torch.Tensor or None: The received tensor for receive operations, None for send operations.
    """
    global STEP
    global VERBOSE
    if operation == 'recv_forward':
        # Skip if this is the first pipeline stage (nothing to receive)
        if pgm.process_group_manager.pp_is_first_stage: return None
        # Allocate an empty buffer to receive activations from the previous stage
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_prev_rank
    elif operation == 'send_forward':
        # Skip if this is the last pipeline stage (nothing to send forward)
        if pgm.process_group_manager.pp_is_last_stage: return
        dest = pgm.process_group_manager.pp_next_rank
    elif operation == 'recv_backward':
        # Skip if this is the last pipeline stage (no gradient to receive)
        if pgm.process_group_manager.pp_is_last_stage: return None
        # Allocate an empty buffer to receive gradients from the next stage
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        src = pgm.process_group_manager.pp_next_rank
    elif operation == 'send_backward':
        # Skip if this is the first pipeline stage (nothing to send backward)
        if pgm.process_group_manager.pp_is_first_stage: return
        dest = pgm.process_group_manager.pp_prev_rank
    # Determine whether this is a send operation and pick the peer rank accordingly
    is_send = operation.startswith('send')
    peer_rank = dest if is_send else src
    # Build the asynchronous P2P operation (isend for sends, irecv for receives)
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if VERBOSE:
        print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} "
              f"{pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | "
              f"STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    # Launch the batched communication and block until it completes
    for req in dist.batch_isend_irecv([op]):
        req.wait()
    torch.cuda.synchronize()
    if VERBOSE: STEP += 1
    # Return the received tensor for receive operations, None for send operations
    return tensor if not is_send else None
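
For context, here is a minimal sketch of how an all-forward-all-backward (AFAB) schedule could drive pipeline_communicate. The names afab_train_step, model, and microbatches are illustrative placeholders rather than definitions from this file, and model is assumed to return a scalar loss on the last pipeline stage.

def afab_train_step(model, microbatches, shapes, device, dtype):
    input_tensors, output_tensors = [], []

    # Phase 1: all forward passes. Each stage receives activations from its
    # predecessor, runs its model shard, and forwards the result downstream.
    for batch in microbatches:
        input_tensor = pipeline_communicate('recv_forward', device, dtype, shapes=shapes)
        output_tensor = model(batch, hidden_states=input_tensor)  # placeholder call signature
        pipeline_communicate('send_forward', device, dtype, tensor=output_tensor)
        # Keep activations alive until the matching backward pass runs.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Phase 2: all backward passes, consumed in the same (FIFO) order.
    for _ in range(len(microbatches)):
        output_grad = pipeline_communicate('recv_backward', device, dtype, shapes=shapes)
        input_tensor = input_tensors.pop(0)
        output_tensor = output_tensors.pop(0)
        # On the last stage output_tensor is the scalar loss, so output_grad is None.
        torch.autograd.backward(output_tensor, grad_tensors=output_grad)
        input_grad = input_tensor.grad if input_tensor is not None else None
        pipeline_communicate('send_backward', device, dtype, tensor=input_grad)

Because the receive buffers are allocated with requires_grad=True, each stage's input_tensor is an autograd leaf whose .grad is populated by the backward pass and can be sent upstream as-is.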