workload_generator/generate_megatron_workload.py

""" Copyright (c) 2021, Alibaba Group; Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ #!/bin/python """example of running megatron on gpt-7B python -m workload_generator.megatron_workload \ --frame=Megatron --world_size=16 --tensor_model_parallel_size=8 --pipeline_model_parallel=1 --global_batch=64 --micro_batch=2 \ --num_layers=32 --seq_length=2048 --hidden_size=4096 --epoch_num=2 --use-distributed-optimizer --enable_sequence_parallel """ from utils.utils import CommGroup, CommType, get_params, WorkloadWriter from workload_generator.workload_generator import WorkloadGenerator from workload_generator.mocked_model.MockedMegatron import MegatronModel from log_analyzer.log import LogItem class MegatronWorkload(WorkloadGenerator): def __init__(self, args, model): super().__init__(args, model) self.name = "megatron" self.args = args self.tp_is_enable = True if args.tensor_model_parallel_size > 1 else False # print(f"total params: {self._get_total_params()}") def _get_total_params(self): total_params = 0 for param in self.model.parameters(): total_params += param.numel() return total_params def _get_layernorm_params(self): total_params = 0 for param in self.model.parameters(): if getattr(param, "sequence_parallel", False): total_params += param.numel() return total_params def init(self): args = self.args self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=1 * 8, stage="init.model_setup", ) ) for _ in range(3): self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=1 * 8, stage="init.model_setup", ) ) if args.pipeline_model_parallel > 1: self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.pp_group, comm_group_size=self.args.pipeline_model_parallel, msg_size=1 * 8, stage="init.model_setup", ) ) # time self.workload.append( LogItem( comm_type=CommType.all_gather, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=4 * 8, stage="init.model_setup", ) ) self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=3 * 8, stage="init.model_setup", src=0, ) ) if args.pp_rank == args.pipeline_model_parallel - 1 and args.pipeline_model_parallel > 1: for p in self.model.embedding.parameters(): self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=p.msg_size(), stage="init.model_setup", ) ) # time self.workload.append( LogItem( comm_type=CommType.all_gather, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=8 * 8, stage="init.model_setup", ) ) def get_pp_rank(self, rank, world_size, pp_size): ranks_per_pp_group = world_size // pp_size pp_rank = rank // ranks_per_pp_group return pp_rank def with_pipeline_forward_backward(self): args = self.args if args.workload_only: rank = 0 else: import torch rank = 
torch.distributed.get_rank() world_size = args.world_size pp_rank = self.get_pp_rank(rank, world_size, args.pipeline_model_parallel) pp_num_warmup_microbatches = min( args.pipeline_model_parallel - pp_rank - 1, args.num_microbatches ) num_microbatches_remaining = args.num_microbatches - pp_num_warmup_microbatches temp = self.model.forward() # forward_comm = self._get_comm_op(temp) for _ in range(pp_num_warmup_microbatches): if pp_rank != 0: # recv_prev self.workload.append( LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", additional="recv_prev", ) ) self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=5 * 8, stage="forward_step", src=0, ) ) self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=8 * (args.world_size + args.seq_length * args.micro_batch), stage="forward_step", src=0, ) ) # for item in forward_comm: self.workload.extend(self.model.forward()) if pp_rank != args.pipeline_model_parallel - 1: # send_next self.workload.append( LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", additional="send_next", ) ) # recv prev if num_microbatches_remaining > 0 and pp_rank != 0: self.workload.append( LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", additional="recv_prev", ) ) for i in range(num_microbatches_remaining): last_iter = i == (num_microbatches_remaining - 1) self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=5 * 8, stage="forward_step", src=0, ) ) self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=8 * (args.world_size + args.seq_length * args.micro_batch), stage="forward_step", src=0, ) ) self.workload.extend(self.model.forward()) if pp_rank != args.pipeline_model_parallel - 1: # recv next self.workload.append( LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", additional="recv_next", ) ) # send next self.workload.append( LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", additional="send_next", ) ) self.workload.extend(self.model.backward()) if pp_rank != 0: if last_iter: # send prev self.workload.append( LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", additional="send_prev", ) ) else: # send prev recv prev self.workload.append( LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", additional="send_prev", ) ) self.workload.append( LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, 
comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", additional="recv_prev", ) ) for _ in range(pp_num_warmup_microbatches): # recv next if pp_rank != args.pipeline_model_parallel - 1: self.workload.append( LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", additional="recv_next", ) ) self.workload.extend(self.model.backward()) # send prev if pp_rank != 0: self.workload.append( LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, comm_group_size=1, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", additional="send_prev", ) ) def forward(self): args = self.args if self.tp_is_enable: self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=5 * 8, stage="forward_step", src=0, ) ) self.workload.append( LogItem( comm_type=CommType.broadcast, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=8 * (args.world_size + args.seq_length * args.micro_batch), stage="forward_step", src=0, ) ) self.workload.extend(self.model.forward()) for _ in range(3): # for bf16, we need to use float32 in loss communication self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=args.micro_batch * args.seq_length * 4, stage="forward_step._VocabParallelCrossEntropy", ) ) # average_losses_across_data_parallel_group self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=1 * 4, stage="forward_step.average_losses_across_data_parallel_group", ) ) def backward(self): self.workload.extend(self.model.backward()) def step(self): args = self.args if args.use_distributed_optimizer: self.workload.append( LogItem( comm_type=CommType.reduce_scatter, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel), stage="step", ) ) self.workload.append( LogItem( comm_type=CommType.all_gather, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=2 * self._get_total_params() // (args.pipeline_model_parallel), stage="step", ) ) else: # 注意,如果使用过了bf16,那么梯度会使用tf32 self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, comm_group_size=self.args.dp_num, msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel), stage="step.finish_grad_sync", ) ) self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=2 * self._get_layernorm_params() // (args.pipeline_model_parallel), stage="step._allreduce_layernorm_grads", ) ) self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, msg_size=4, stage="step.check_for_nan", ) ) if __name__ == "__main__": args = get_params() model = MegatronModel(args) workload_generator = MegatronWorkload(args, model) workload = workload_generator() filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv" 
workload.dump(filename) if args.enable_visual: try: from visualize.generate import visualize_output base_name = filename.split(".")[0] visualize_output(f"./results/mocked_workload/{base_name}_workload.csv",True) except ImportError: print("visualize_output is not available because required library is not found") # WorkloadWriter.write_workload(workload, args, filename)
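
For a quick look at what gets dumped, here is a minimal inspection sketch. It assumes workload.dump() above writes a plain CSV whose columns mirror the LogItem fields used in this file (comm_type, comm_group, msg_size, stage); the concrete filename below is only an illustration of the pattern built in __main__, not a guaranteed output path.

import pandas as pd

# Hypothetical output name following the filename pattern above; point this at the
# file your run actually produced.
df = pd.read_csv("megatron_gpt_7B_sp_True_iteration_2_computationEnable_False_16n.csv")

print(df.columns.tolist())  # see which LogItem fields were serialized
if {"comm_type", "msg_size"}.issubset(df.columns):
    # Total bytes per collective type, assuming these columns exist in the dump.
    print(df.groupby("comm_type")["msg_size"].sum().sort_values(ascending=False))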