in workload_applyer.py [0:0]
def _apply_p2pcommunication(self, item):
    """Replay one pipeline-parallel point-to-point exchange described by `item`."""
    ops = []
    # item.msg_size is a byte count; // 2 converts it to a bfloat16 element count
    # (matching the recv buffers allocated below).
    tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2)
    if item.additional == "send_prev":
        # The first pipeline stage has no previous rank to send to.
        if self._get_pipeline_parallel_rank() != 0:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor, self._get_pipeline_prev_rank()
            )
            ops.append(send_prev_op)
    elif item.additional == "send_next":
        # The last pipeline stage has no next rank to send to.
        if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor, self._get_pipeline_next_rank()
            )
            ops.append(send_next_op)
    elif item.additional == "recv_prev":
        # Only non-first stages receive activations from the previous rank.
        if self._get_pipeline_parallel_rank() != 0:
            tensor_recv_prev = torch.empty(
                item.msg_size // 2, dtype=torch.bfloat16, device=self.device
            )
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                tensor_recv_prev,
                self._get_pipeline_prev_rank(),
            )
            ops.append(recv_prev_op)
    elif item.additional == "recv_next":
        # Only non-last stages receive gradients from the next rank.
        if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1:
            tensor_recv_next = torch.empty(
                item.msg_size // 2, dtype=torch.bfloat16, device=self.device
            )
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv,
                tensor_recv_next,
                self._get_pipeline_next_rank(),
            )
            ops.append(recv_next_op)
    if len(ops) > 0:
        # Launch all queued sends/receives as one batch, then block until done.
        reqs = torch.distributed.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    torch.cuda.synchronize()
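
A minimal, self-contained sketch of the batched point-to-point pattern the method above relies on: queue torch.distributed.P2POp objects, submit them with batch_isend_irecv, and wait on the returned requests. Everything below (the two-rank gloo/CPU setup, MSG_ELEMS, run_batched_p2p) is an illustrative assumption, not code taken from workload_applyer.py.

import torch
import torch.distributed as dist


def run_batched_p2p():
    # Assumes launch via `torchrun --nproc_per_node=2 this_script.py`, which sets
    # RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT for env:// initialization.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    assert dist.get_world_size() == 2, "sketch assumes exactly two ranks"
    peer = 1 - rank

    MSG_ELEMS = 1024  # hypothetical message size, in elements
    send_buf = torch.full((MSG_ELEMS,), float(rank))
    recv_buf = torch.empty(MSG_ELEMS)

    # Pair every isend with a matching irecv so the batch cannot deadlock.
    ops = [
        dist.P2POp(dist.isend, send_buf, peer),
        dist.P2POp(dist.irecv, recv_buf, peer),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    # Each rank now holds the other rank's payload.
    assert torch.all(recv_buf == float(peer))
    dist.destroy_process_group()


if __name__ == "__main__":
    run_batched_p2p()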