def step()

in workload_generator/generate_megatron_workload.py


    def step(self):
        args = self.args

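        # With the distributed optimizer, gradient synchronization is modeled as a
        # reduce-scatter followed by a parameter all-gather over the DP group;
        # otherwise a single gradient all-reduce over the DP group is emitted.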
        if args.use_distributed_optimizer:
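            # Gradient reduce-scatter: 4 bytes per parameter (fp32-sized grads),
            # with the parameter count split across pipeline stages.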
            self.workload.append(
                LogItem(
                    comm_type=CommType.reduce_scatter,
                    comm_group=CommGroup.dp_group,
                    comm_group_size=self.args.dp_num,
                    msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel),
                    stage="step",
                )
            )
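            # Parameter all-gather: 2 bytes per parameter (presumably bf16 weights).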
            self.workload.append(
                LogItem(
                    comm_type=CommType.all_gather,
                    comm_group=CommGroup.dp_group,
                    comm_group_size=self.args.dp_num,
                    msg_size=2 * self._get_total_params() // (args.pipeline_model_parallel),
                    stage="step",
                )
            )
        else:
            # Note: even when bf16 is used, gradients are reduced in fp32 (4 bytes per parameter).
            self.workload.append(
                LogItem(
                    comm_type=CommType.all_reduce,
                    comm_group=CommGroup.dp_group,
                    comm_group_size=self.args.dp_num,
                    msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel),
                    stage="step.finish_grad_sync",
                )
            )

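        # step._allreduce_layernorm_grads: layernorm gradients are all-reduced
        # across the tensor-parallel group (2 bytes per layernorm parameter).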
        self.workload.append(
            LogItem(
                comm_type=CommType.all_reduce,
                comm_group=CommGroup.tp_group,
                comm_group_size=self.args.tensor_model_parallel_size,
                msg_size=2 * self._get_layernorm_params() // (args.pipeline_model_parallel),
                stage="step._allreduce_layernorm_grads",
            )
        )

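        # step.check_for_nan: a single 4-byte scalar all-reduced over the TP group.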
        self.workload.append(
            LogItem(
                comm_type=CommType.all_reduce,
                comm_group=CommGroup.tp_group,
                comm_group_size=self.args.tensor_model_parallel_size,
                msg_size=4,
                stage="step.check_for_nan",
            )
        )
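
As a rough illustration (not part of the generator), the sketch below plugs hypothetical numbers into the same formulas step() uses on the distributed-optimizer path: a 7B-parameter model split over 4 pipeline stages.

    # Back-of-the-envelope message sizes for the distributed-optimizer path.
    # All values below are hypothetical stand-ins, not taken from the repository.
    total_params = 7_000_000_000        # stand-in for self._get_total_params()
    pipeline_model_parallel = 4         # stand-in for args.pipeline_model_parallel

    params_per_stage = total_params // pipeline_model_parallel

    reduce_scatter_bytes = 4 * params_per_stage   # fp32-sized grads: 7.0e9 B, about 6.5 GiB
    all_gather_bytes = 2 * params_per_stage       # bf16-sized params: 3.5e9 B, about 3.3 GiB

    print(f"reduce-scatter: {reduce_scatter_bytes / 2**30:.1f} GiB")
    print(f"all-gather:     {all_gather_bytes / 2**30:.1f} GiB")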