in workload_generator/generate_deepspeed_stage1_2_workload.py [0:0]
def init(self):
    """Record the communication ops issued during the ``init`` phase.

    When ``self.amp_enabled`` is falsy, one broadcast over the data-parallel
    group is appended to ``self.workload`` for every parameter yielded by
    ``self.model.parameters()``, sized by that parameter's ``msg_size()``.
    A single barrier over ``CommGroup.all`` is then appended unconditionally.
    """
    if not self.amp_enabled:
        for param in self.model.parameters():
            self.workload.append(
                LogItem(
                    comm_type=CommType.broadcast,
                    comm_group=CommGroup.dp_group,
                    comm_group_size=self.dp_world_size,
                    msg_size=param.msg_size(),
                    stage="init",
                )
            )
    # The barrier is recorded regardless of AMP, so it must not read the loop
    # variable ``param``. That variable is unbound when the loop above is
    # skipped (AMP enabled, or no parameters), which previously raised
    # UnboundLocalError. A barrier also carries no payload, so its size is 0
    # rather than the size of whichever parameter happened to come last.
    # NOTE(review): comm_group is ``all`` but comm_group_size is the DP world
    # size; presumably these coincide for this workload. Confirm.
    self.workload.append(
        LogItem(
            comm_type=CommType.barrier,
            comm_group=CommGroup.all,
            comm_group_size=self.dp_world_size,
            msg_size=0,
            stage="init.__init__",
        )
    )