in picotron/pipeline_parallel/pp_communications.py
import torch
import torch.distributed as dist
import picotron.process_group_manager as pgm  # assumed import for the global process-group handle referenced below

STEP = 0         # module-level debug counter; assumed to be defined near the top of the full file
VERBOSE = False  # module-level debug flag; assumed likewise

def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
    global STEP
    global VERBOSE
is_fwd = (operation == 'send_fwd_recv_bwd')
    # Nothing to exchange at the pipeline boundary: the last stage has no next stage on the
    # forward path, and the first stage has no previous stage on the backward path.
    if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage):
        return None
peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank
recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    # Post the send and the matching receive as one batched P2P exchange so both directions
    # progress concurrently and neighbouring stages cannot deadlock on operation ordering.
    reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank),
                                   dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
    if VERBOSE:
        print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "
              f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "
              f"{STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)
    # Wait on both requests, then synchronize so the exchange has fully completed
    # on the device before the received tensor is used.
    for req in reqs:
        req.wait()
    torch.cuda.synchronize()
if VERBOSE: STEP += 1
return recv_tensor
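
For context, here is a minimal sketch of how this helper could be driven from the steady-state phase of a 1F1B pipeline schedule. The names model_forward, model_backward, input_tensor, tensors_shape, device, and dtype are illustrative stand-ins for the surrounding training loop, and the operation string 'send_bwd_recv_fwd' is assumed as the counterpart of 'send_fwd_recv_bwd' (any other string takes the backward branch above); none of this is part of this file.

# Hypothetical usage sketch (names are illustrative, not from this file):
output_tensor = model_forward(input_tensor)            # forward pass of this stage on one microbatch
recv_grad = bidirectional_pipeline_communicate(        # ship the activation forward, get its gradient back
    operation='send_fwd_recv_bwd', send_tensor=output_tensor,
    recv_shapes=tensors_shape, device=device, dtype=dtype)
input_grad = model_backward(output_tensor, recv_grad)  # backward pass of this stage
next_input = bidirectional_pipeline_communicate(       # ship the gradient backward, get the next microbatch's activation
    operation='send_bwd_recv_fwd', send_tensor=input_grad,
    recv_shapes=tensors_shape, device=device, dtype=dtype)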