def _apply_p2pcommunication()

in workload_applyer.py [0:0]


    def _apply_p2pcommunication(self, item):
        """Issue the point-to-point send/recv described by a workload item."""
        ops = []
        # item.msg_size is in bytes; the buffer holds bfloat16 elements
        # (2 bytes each), so msg_size // 2 is the element count to transfer.
        tensor = torch.narrow(self.buffer, 0, 0, item.msg_size // 2)
        if item.additional == "send_prev":
            # The first pipeline stage has no previous rank to send to.
            if self._get_pipeline_parallel_rank() != 0:
                send_prev_op = torch.distributed.P2POp(
                    torch.distributed.isend, tensor, self._get_pipeline_prev_rank()
                )
                ops.append(send_prev_op)
        elif item.additional == "send_next":
            # The last pipeline stage has no next rank to send to.
            if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1:
                send_next_op = torch.distributed.P2POp(
                    torch.distributed.isend, tensor, self._get_pipeline_next_rank()
                )
                ops.append(send_next_op)
        elif item.additional == "recv_prev":
            # The first pipeline stage has no previous rank to receive from.
            if self._get_pipeline_parallel_rank() != 0:
                tensor_recv_prev = torch.empty(
                    item.msg_size // 2, dtype=torch.bfloat16, device=self.device
                )
                recv_prev_op = torch.distributed.P2POp(
                    torch.distributed.irecv,
                    tensor_recv_prev,
                    self._get_pipeline_prev_rank(),
                )
                ops.append(recv_prev_op)
        elif item.additional == "recv_next":
            # The last pipeline stage has no next rank to receive from.
            if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1:
                tensor_recv_next = torch.empty(
                    item.msg_size // 2, dtype=torch.bfloat16, device=self.device
                )
                recv_next_op = torch.distributed.P2POp(
                    torch.distributed.irecv,
                    tensor_recv_next,
                    self._get_pipeline_next_rank(),
                )
                ops.append(recv_next_op)

        # Batch the queued P2P operations, wait for every request, then
        # synchronize the device so timing reflects the completed transfer.
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()

        torch.cuda.synchronize()
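
For context, below is a minimal standalone sketch of the same batched P2P pattern (build P2POps, batch_isend_irecv, wait). The script name, the gloo/CPU backend, the float32 dtype, the two-rank torchrun launch, and the ring-style send/recv are assumptions made so the sketch runs without GPUs; the method above operates on a bfloat16 GPU buffer inside the workload applyer, where the 2-bytes-per-element size explains the msg_size // 2.

    # p2p_sketch.py -- illustrative only; run as:
    #   torchrun --nproc_per_node=2 p2p_sketch.py
    import torch
    import torch.distributed as dist

    def main():
        # gloo/CPU backend is an assumption so this sketch needs no GPUs.
        dist.init_process_group(backend="gloo")
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        numel = 8  # element count; the real method derives this from item.msg_size
        send_buf = torch.full((numel,), float(rank), dtype=torch.float32)
        recv_buf = torch.empty(numel, dtype=torch.float32)

        ops = []
        if rank != world_size - 1:  # every stage but the last sends to its next rank
            ops.append(dist.P2POp(dist.isend, send_buf, rank + 1))
        if rank != 0:  # every stage but the first receives from its previous rank
            ops.append(dist.P2POp(dist.irecv, recv_buf, rank - 1))

        # Same completion pattern as the method above: batch, then wait on each request.
        if ops:
            for req in dist.batch_isend_irecv(ops):
                req.wait()

        if rank != 0:
            print(f"rank {rank} received {recv_buf.tolist()} from rank {rank - 1}")

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()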