in workload_generator/generate_deepspeed_stage3_workload.py [0:0]
def _compute_for_param(self, param):
    """Append the computation log entries for ``param`` to ``self.workload``.

    What gets appended depends on ``self.stage``:

    * ``"forward"``: one entry whose ``msg_size`` pairs
      ``(batch_size, seq_len, shape[0])`` with ``(shape[0], shape[1])``.
      Skipped when the last dimension of the parameter shape is 1.
    * ``"backward"``: the same entry as above (input grad, same skip rule),
      followed by a weight-grad entry pairing
      ``(shape[0], batch_size * seq_len)`` with
      ``(batch_size * seq_len, shape[1])``.
    * any other stage: nothing is appended.

    Every entry uses ``CommType.computation`` and is tagged
    ``"<stage>.computation"``.

    Args:
        param: Object exposing ``get_shape()``; the returned shape is indexed
            at ``[-1]``, ``[0]`` and ``[1]``, so it needs at least two
            dimensions whenever an entry is emitted.
    """
    stage = self.stage
    if stage not in ("forward", "backward"):
        return

    shape = param.get_shape()
    stage_tag = f"{stage}.computation"

    def record(first_shape, second_shape):
        # msg_size carries a pair of shapes -- presumably the two operands of
        # the matmul being modelled (TODO confirm against the LogItem consumer).
        self.workload.append(
            LogItem(
                comm_type=CommType.computation,
                msg_size=(first_shape, second_shape),
                stage=stage_tag,
            )
        )

    # Forward compute and backward input grad are described by the same pair
    # of shapes; parameters whose last dimension is 1 are skipped.
    if shape[-1] != 1:
        record(
            (self.batch_size, self.seq_len, shape[0]),
            (shape[0], shape[1]),
        )

    if stage == "backward":
        # weight grad
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether this entry was meant to sit under the ``!= 1`` guard as
        # well; it is emitted for every backward call here -- confirm.
        # It now carries the same stage tag as the other computation entries
        # (the original omitted ``stage`` for this one).
        tokens = self.batch_size * self.seq_len
        record(
            (shape[0], tokens),
            (tokens, shape[1]),
        )