in picotron/context_parallel/cp_communications.py [0:0]
def wait(self):
    """Block until every active request has completed, then reset state.

    Calls ``wait()`` on each entry of ``self._active_requests`` in order,
    synchronizes the CUDA device, and finally clears both
    ``self._active_requests`` (set to ``None``) and
    ``self._pending_operations`` (set to an empty list).

    Raises:
        RuntimeError: If ``self._active_requests`` is ``None``.
    """
    handles = self._active_requests
    if handles is None:
        raise RuntimeError("Wait called before commit")

    for idx, handle in enumerate(handles):
        handle.wait()
        if not VERBOSE:
            continue
        # The log labels even-indexed requests as sends and odd-indexed
        # ones as receives.
        # NOTE(review): this presumes requests were queued in alternating
        # send/receive order -- confirm against the code that builds them.
        if idx % 2 == 0:
            kind, direction, peer = "send", "TO", self.send_rank
        else:
            kind, direction, peer = "receive", "FROM", self.recv_rank
        print(
            f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "
            f"ACTION:completed_{kind} | "
            f"{direction}:{peer}",
            flush=True,
        )

    # Synchronize once, after every request handle has reported completion.
    torch.cuda.synchronize()

    # Reset the bookkeeping; a further wait() will raise until
    # _active_requests is populated again.
    self._active_requests = None
    self._pending_operations = []

    if VERBOSE:
        print(
            f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "
            f"ACTION:all_operations_completed",
            flush=True,
        )