def _compute_for_param()

in workload_generator/generate_deepspeed_stage3_workload.py [0:0]


    def _compute_for_param(self, param):
        """Append the computation items for one parameter to ``self.workload``.

        Reads ``self.stage`` and emits ``LogItem`` entries with
        ``comm_type=CommType.computation`` whose ``msg_size`` is a pair of
        operand shapes:

        * ``"forward"``: one item, ``(batch_size, seq_len, shape[0])`` paired
          with ``(shape[0], shape[1])``.
        * ``"backward"``: two items -- the "input grad" item (same operand
          shapes as the forward item) followed by the "weight grad" item,
          ``(shape[0], batch_size * seq_len)`` paired with
          ``(batch_size * seq_len, shape[1])``.

        Every emitted item is labelled ``f"{self.stage}.computation"``.
        Parameters whose last dimension equals 1 are skipped, and any other
        stage value appends nothing.

        Args:
            param: Object exposing ``get_shape()``; the result is indexed at
                0, 1 and -1 (presumably a 2-D weight shape -- TODO confirm
                against callers).
        """
        stage = self.stage
        if stage not in ("forward", "backward"):
            return

        # Fetch the shape once instead of re-calling get_shape() per operand.
        shape = param.get_shape()
        if shape[-1] == 1:
            return

        stage_label = f"{stage}.computation"
        dim_in, dim_out = shape[0], shape[1]

        # The forward item and the backward "input grad" item use identical
        # operand shapes, so a single append serves both stages.
        self.workload.append(
            LogItem(
                comm_type=CommType.computation,
                msg_size=(
                    (self.batch_size, self.seq_len, dim_in),
                    (dim_in, dim_out),
                ),
                stage=stage_label,
            )
        )

        if stage == "backward":
            # weight grad
            tokens = self.batch_size * self.seq_len
            self.workload.append(
                LogItem(
                    comm_type=CommType.computation,
                    msg_size=(
                        (dim_in, tokens),
                        (tokens, dim_out),
                    ),
                    # Label this entry like its siblings; it was previously
                    # left at LogItem's default stage.
                    stage=stage_label,
                )
            )