def with_pipeline_forward_backward()

in workload_generator/generate_megatron_workload.py [0:0]


    def with_pipeline_forward_backward(self):
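        """Emit the communication workload for a 1F1B pipeline schedule.

        Three phases are generated per rank: warmup forward passes until the
        pipeline is full, a steady-state phase that alternates one forward and
        one backward per remaining microbatch, and a cooldown phase that drains
        the outstanding backward passes. Point-to-point message sizes are
        hidden_size * seq_length * micro_batch elements at 2 bytes each
        (presumably fp16/bf16 activations).
        """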
        args = self.args
        if args.workload_only:
            rank = 0
        else:
            import torch
            rank = torch.distributed.get_rank()
        world_size = args.world_size
        pp_rank = self.get_pp_rank(rank, world_size, args.pipeline_model_parallel)
        pp_num_warmup_microbatches = min(
            args.pipeline_model_parallel - pp_rank - 1, args.num_microbatches
        )
        num_microbatches_remaining = args.num_microbatches - pp_num_warmup_microbatches
        # The result of this initial forward() call is unused; per-microbatch
        # comm items are regenerated inside the loops below.
        temp = self.model.forward()
        # forward_comm = self._get_comm_op(temp)

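        # Warmup phase: run forward passes until the pipeline is full. All
        # stages except the first receive activations from the previous stage;
        # all stages except the last send activations to the next stage.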
        for _ in range(pp_num_warmup_microbatches):
            if pp_rank != 0:
                # recv_prev
                self.workload.append(
                    LogItem(
                        comm_type=CommType.irecv,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="forward_step",
                        additional="recv_prev",
                    )
                )
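            # Per-microbatch TP-group broadcasts from rank 0, apparently
            # modelling the batch-metadata and token-data exchange (the sizes
            # suggest 8-byte elements).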
            self.workload.append(
                LogItem(
                    comm_type=CommType.broadcast,
                    comm_group=CommGroup.tp_group,
                    comm_group_size=self.args.tensor_model_parallel_size,
                    msg_size=5 * 8,
                    stage="forward_step",
                    src=0,
                )
            )
            self.workload.append(
                LogItem(
                    comm_type=CommType.broadcast,
                    comm_group=CommGroup.tp_group,
                    comm_group_size=self.args.tensor_model_parallel_size,
                    msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
                    stage="forward_step",
                    src=0,
                )
            )

            # for item in forward_comm:
            self.workload.extend(self.model.forward())

            if pp_rank != args.pipeline_model_parallel - 1:
                # send_next
                self.workload.append(
                    LogItem(
                        comm_type=CommType.isend,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="forward_step",
                        additional="send_next",
                    )
                )
        # recv prev
        if num_microbatches_remaining > 0 and pp_rank != 0:
            self.workload.append(
                LogItem(
                    comm_type=CommType.irecv,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="forward_step",
                    additional="recv_prev",
                )
            )

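        # Steady-state (1F1B) phase: one forward followed by one backward per
        # remaining microbatch.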
        for i in range(num_microbatches_remaining):
            last_iter = i == (num_microbatches_remaining - 1)
            self.workload.append(
                LogItem(
                    comm_type=CommType.broadcast,
                    comm_group=CommGroup.tp_group,
                    comm_group_size=self.args.tensor_model_parallel_size,
                    msg_size=5 * 8,
                    stage="forward_step",
                    src=0,
                )
            )
            self.workload.append(
                LogItem(
                    comm_type=CommType.broadcast,
                    comm_group=CommGroup.tp_group,
                    comm_group_size=self.args.tensor_model_parallel_size,
                    msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
                    stage="forward_step",
                    src=0,
                )
            )

            self.workload.extend(self.model.forward())
            if pp_rank != args.pipeline_model_parallel - 1:
                # recv next
                self.workload.append(
                    LogItem(
                        comm_type=CommType.irecv,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="forward_step",
                        additional="recv_next",
                    )
                )
                # send next
                self.workload.append(
                    LogItem(
                        comm_type=CommType.isend,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="forward_step",
                        additional="send_next",
                    )
                )

            self.workload.extend(self.model.backward())

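            # After backward: send the gradient to the previous stage; unless
            # this is the last steady-state microbatch, also post a receive for
            # the next forward input.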
            if pp_rank != 0:
                if last_iter:
                    # send prev
                    self.workload.append(
                        LogItem(
                            comm_type=CommType.isend,
                            comm_group=CommGroup.pp_group,
                            comm_group_size=1,
                            msg_size=2
                            * (args.hidden_size * args.seq_length * args.micro_batch),
                            stage="backward_step",
                            additional="send_prev",
                        )
                    )
                else:
                    # send prev recv prev
                    self.workload.append(
                        LogItem(
                            comm_type=CommType.isend,
                            comm_group=CommGroup.pp_group,
                            comm_group_size=1,
                            msg_size=2
                            * (args.hidden_size * args.seq_length * args.micro_batch),
                            stage="backward_step",
                            additional="send_prev",
                        )
                    )
                    self.workload.append(
                        LogItem(
                            comm_type=CommType.irecv,
                            comm_group=CommGroup.pp_group,
                            comm_group_size=1,
                            msg_size=2
                            * (args.hidden_size * args.seq_length * args.micro_batch),
                            stage="backward_step",
                            additional="recv_prev",
                        )
                    )

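        # Cooldown phase: drain the warmup microbatches with the remaining
        # backward passes.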
        for _ in range(pp_num_warmup_microbatches):
            # recv next
            if pp_rank != args.pipeline_model_parallel - 1:
                self.workload.append(
                    LogItem(
                        comm_type=CommType.irecv,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="backward_step",
                        additional="recv_next",
                    )
                )

            self.workload.extend(self.model.backward())

            # send prev
            if pp_rank != 0:
                self.workload.append(
                    LogItem(
                        comm_type=CommType.isend,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="backward_step",
                        additional="send_prev",
                    )
                )
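
As a quick sanity check (not part of the source), the sketch below evaluates the warmup/steady-state split used above for a hypothetical 4-stage pipeline with 8 microbatches; the standalone names mirror the corresponding `args` fields.

    # Illustration only: how pp_num_warmup_microbatches and
    # num_microbatches_remaining behave per pipeline stage.
    pipeline_model_parallel = 4  # hypothetical PP degree
    num_microbatches = 8         # hypothetical microbatch count
    for pp_rank in range(pipeline_model_parallel):
        warmup = min(pipeline_model_parallel - pp_rank - 1, num_microbatches)
        remaining = num_microbatches - warmup
        print(f"stage {pp_rank}: {warmup} warmup forwards, {remaining} 1F1B iterations")
    # stage 0: 3 warmup forwards, 5 1F1B iterations
    # stage 1: 2 warmup forwards, 6 1F1B iterations
    # stage 2: 1 warmup forwards, 7 1F1B iterations
    # stage 3: 0 warmup forwards, 8 1F1B iterations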