in workload_generator/generate_megatron_workload.py [0:0]
def init(self):
    args = self.args
    # one 8-byte all-reduce over the data-parallel group during model setup
    self.workload.append(
        LogItem(
            comm_type=CommType.all_reduce,
            comm_group=CommGroup.dp_group,
            comm_group_size=self.args.dp_num,
            msg_size=1 * 8,
            stage="init.model_setup",
        )
    )
    # three further 8-byte all-reduces over the data-parallel group
    for _ in range(3):
        self.workload.append(
            LogItem(
                comm_type=CommType.all_reduce,
                comm_group=CommGroup.dp_group,
                comm_group_size=self.args.dp_num,
                msg_size=1 * 8,
                stage="init.model_setup",
            )
        )
    # with pipeline parallelism enabled, one 8-byte all-reduce over the pipeline-parallel group
    if args.pipeline_model_parallel > 1:
        self.workload.append(
            LogItem(
                comm_type=CommType.all_reduce,
                comm_group=CommGroup.pp_group,
                comm_group_size=self.args.pipeline_model_parallel,
                msg_size=1 * 8,
                stage="init.model_setup",
            )
        )
    # gather timing values (4 x 8 bytes) across the data-parallel group
    self.workload.append(
        LogItem(
            comm_type=CommType.all_gather,
            comm_group=CommGroup.dp_group,
            comm_group_size=self.args.dp_num,
            msg_size=4 * 8,
            stage="init.model_setup",
        )
    )
    # broadcast 3 x 8 bytes from rank 0 within the tensor-parallel group
    self.workload.append(
        LogItem(
            comm_type=CommType.broadcast,
            comm_group=CommGroup.tp_group,
            comm_group_size=self.args.tensor_model_parallel_size,
            msg_size=3 * 8,
            stage="init.model_setup",
            src=0,
        )
    )
    # on the last pipeline stage, all-reduce each embedding parameter across the tensor-parallel group
    if args.pp_rank == args.pipeline_model_parallel - 1 and args.pipeline_model_parallel > 1:
        for p in self.model.embedding.parameters():
            self.workload.append(
                LogItem(
                    comm_type=CommType.all_reduce,
                    comm_group=CommGroup.tp_group,
                    comm_group_size=self.args.tensor_model_parallel_size,
                    msg_size=p.msg_size(),
                    stage="init.model_setup",
                )
            )
    # gather timing values (8 x 8 bytes) across the data-parallel group
    self.workload.append(
        LogItem(
            comm_type=CommType.all_gather,
            comm_group=CommGroup.dp_group,
            comm_group_size=self.args.dp_num,
            msg_size=8 * 8,
            stage="init.model_setup",
        )
    )
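
# --- Hypothetical usage sketch (not part of the original file) ---
# Assuming this init() lives on a generator class (called MegatronWorkload
# here purely for illustration) that owns the `workload` list of LogItem
# entries, the collectives emitted above could be summarized like this:
#
#     from collections import Counter
#
#     gen = MegatronWorkload(args, model)  # assumed constructor signature
#     gen.init()
#     counts = Counter(
#         (item.comm_type, item.comm_group) for item in gen.workload
#     )
#     for (comm_type, comm_group), n in counts.items():
#         print(f"{comm_type} x{n} on {comm_group}")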