#!/bin/python
"""
Copyright (c) 2021, Alibaba Group;
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""example of running megatron on gpt-7B
python -m workload_generator.megatron_workload \
--frame=Megatron --world_size=16 --tensor_model_parallel_size=8 --pipeline_model_parallel=1 --global_batch=64 --micro_batch=2 \
--num_layers=32 --seq_length=2048 --hidden_size=4096 --epoch_num=2 --use-distributed-optimizer --enable_sequence_parallel
"""
from utils.utils import CommGroup, CommType, get_params, WorkloadWriter
from workload_generator.workload_generator import WorkloadGenerator
from workload_generator.mocked_model.MockedMegatron import MegatronModel
from log_analyzer.log import LogItem


class MegatronWorkload(WorkloadGenerator):
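    """Mocked Megatron communication workload, recorded as a list of LogItem entries.

    init() emits the model-setup collectives; forward()/backward() and
    with_pipeline_forward_backward() emit the per-iteration communication;
    step() emits the gradient-synchronization traffic of the optimizer step.
    """
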
def __init__(self, args, model):
super().__init__(args, model)
self.name = "megatron"
self.args = args
        self.tp_is_enable = args.tensor_model_parallel_size > 1
# print(f"total params: {self._get_total_params()}")

    def _get_total_params(self):
        """Return the total parameter count of the mocked model."""
total_params = 0
for param in self.model.parameters():
total_params += param.numel()
return total_params

    def _get_layernorm_params(self):
        """Return the count of parameters flagged `sequence_parallel` (layer-norm weights)."""
total_params = 0
for param in self.model.parameters():
if getattr(param, "sequence_parallel", False):
total_params += param.numel()
return total_params

    def init(self):
        """Emit the collectives issued during model setup, before training starts."""
args = self.args
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=1 * 8,
stage="init.model_setup",
)
)
for _ in range(3):
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=1 * 8,
stage="init.model_setup",
)
)
if args.pipeline_model_parallel > 1:
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.pp_group,
comm_group_size=self.args.pipeline_model_parallel,
msg_size=1 * 8,
stage="init.model_setup",
)
)
# time
self.workload.append(
LogItem(
comm_type=CommType.all_gather,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=4 * 8,
stage="init.model_setup",
)
)
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=3 * 8,
stage="init.model_setup",
src=0,
)
)
if args.pp_rank == args.pipeline_model_parallel - 1 and args.pipeline_model_parallel > 1:
for p in self.model.embedding.parameters():
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=p.msg_size(),
stage="init.model_setup",
)
)
# time
self.workload.append(
LogItem(
comm_type=CommType.all_gather,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=8 * 8,
stage="init.model_setup",
)
)

    def get_pp_rank(self, rank, world_size, pp_size):
        """Map a global rank to its pipeline-parallel rank.

        Ranks are grouped contiguously: with world_size=16 and pp_size=2,
        ranks 0-7 map to pp_rank 0 and ranks 8-15 map to pp_rank 1.
        """
ranks_per_pp_group = world_size // pp_size
pp_rank = rank // ranks_per_pp_group
return pp_rank

    def with_pipeline_forward_backward(self):
        """Emit the communication of one training iteration under the 1F1B pipeline schedule."""
args = self.args
if args.workload_only:
rank = 0
else:
import torch
rank = torch.distributed.get_rank()
world_size = args.world_size
pp_rank = self.get_pp_rank(rank, world_size, args.pipeline_model_parallel)
pp_num_warmup_microbatches = min(
args.pipeline_model_parallel - pp_rank - 1, args.num_microbatches
)
num_microbatches_remaining = args.num_microbatches - pp_num_warmup_microbatches
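        # 1F1B schedule: each pipeline stage runs pp_num_warmup_microbatches
        # warmup forwards, then num_microbatches_remaining steady-state 1F1B
        # pairs, then the same number of cooldown backwards as warmup forwards.
        # Example: pipeline_model_parallel=4, pp_rank=1, num_microbatches=8
        # gives 2 warmup forwards, 6 steady-state iterations, 2 cooldown backwards.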
temp = self.model.forward()
# forward_comm = self._get_comm_op(temp)
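        # Point-to-point activation messages below are sized as
        # 2 * hidden_size * seq_length * micro_batch bytes, i.e. one activation
        # tensor at (presumably) 2 bytes per element for bf16/fp16.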
for _ in range(pp_num_warmup_microbatches):
if pp_rank != 0:
# recv_prev
self.workload.append(
LogItem(
comm_type=CommType.irecv,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="forward_step",
additional="recv_prev",
)
)
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=5 * 8,
stage="forward_step",
src=0,
)
)
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
stage="forward_step",
src=0,
)
)
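            # The two TP broadcasts above likely mirror Megatron's
            # broadcast_data() in get_batch(): tp rank 0 first broadcasts the
            # int64 size metadata, then the flattened batch tensors.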
# for item in forward_comm:
self.workload.extend(self.model.forward())
if pp_rank != args.pipeline_model_parallel - 1:
# send_next
self.workload.append(
LogItem(
comm_type=CommType.isend,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="forward_step",
additional="send_next",
)
)
# recv prev
if num_microbatches_remaining > 0 and pp_rank != 0:
self.workload.append(
LogItem(
comm_type=CommType.irecv,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="forward_step",
additional="recv_prev",
)
)
for i in range(num_microbatches_remaining):
last_iter = i == (num_microbatches_remaining - 1)
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=5 * 8,
stage="forward_step",
src=0,
)
)
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
stage="forward_step",
src=0,
)
)
self.workload.extend(self.model.forward())
if pp_rank != args.pipeline_model_parallel - 1:
# recv next
self.workload.append(
LogItem(
comm_type=CommType.irecv,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="forward_step",
additional="recv_next",
)
)
# send next
self.workload.append(
LogItem(
comm_type=CommType.isend,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="forward_step",
additional="send_next",
)
)
self.workload.extend(self.model.backward())
if pp_rank != 0:
if last_iter:
# send prev
self.workload.append(
LogItem(
comm_type=CommType.isend,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="backward_step",
additional="send_prev",
)
)
else:
# send prev recv prev
self.workload.append(
LogItem(
comm_type=CommType.isend,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="backward_step",
additional="send_prev",
)
)
self.workload.append(
LogItem(
comm_type=CommType.irecv,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="backward_step",
additional="recv_prev",
)
)
for _ in range(pp_num_warmup_microbatches):
# recv next
if pp_rank != args.pipeline_model_parallel - 1:
self.workload.append(
LogItem(
comm_type=CommType.irecv,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="backward_step",
additional="recv_next",
)
)
self.workload.extend(self.model.backward())
# send prev
if pp_rank != 0:
self.workload.append(
LogItem(
comm_type=CommType.isend,
comm_group=CommGroup.pp_group,
comm_group_size=1,
msg_size=2
* (args.hidden_size * args.seq_length * args.micro_batch),
stage="backward_step",
additional="send_prev",
)
)

    def forward(self):
        """Emit the communication of one non-pipelined forward pass."""
args = self.args
if self.tp_is_enable:
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=5 * 8,
stage="forward_step",
src=0,
)
)
self.workload.append(
LogItem(
comm_type=CommType.broadcast,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
stage="forward_step",
src=0,
)
)
self.workload.extend(self.model.forward())
        # _VocabParallelCrossEntropy issues three TP all-reduces; even with bf16
        # training the loss terms are communicated in float32 (4 bytes/element).
        for _ in range(3):
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=args.micro_batch * args.seq_length * 4,
stage="forward_step._VocabParallelCrossEntropy",
)
)
# average_losses_across_data_parallel_group
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=1 * 4,
stage="forward_step.average_losses_across_data_parallel_group",
)
)

    def backward(self):
        """Emit the backward-pass communication (delegated to the mocked model)."""
        self.workload.extend(self.model.backward())

    def step(self):
        """Emit the optimizer-step communication (gradient sync and sanity checks)."""
args = self.args
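        # Gradient synchronization per optimizer step. With the distributed
        # optimizer, gradients are reduce-scattered at 4 bytes per parameter
        # (fp32) and updated parameters are all-gathered at 2 bytes per
        # parameter (presumably bf16); both sizes are divided by the number of
        # pipeline stages to approximate the per-stage parameter share.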
if args.use_distributed_optimizer:
self.workload.append(
LogItem(
comm_type=CommType.reduce_scatter,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel),
stage="step",
)
)
self.workload.append(
LogItem(
comm_type=CommType.all_gather,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=2 * self._get_total_params() // (args.pipeline_model_parallel),
stage="step",
)
)
else:
            # Note: when bf16 is enabled, gradients are still reduced in 32-bit
            # precision (4 bytes per element).
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.dp_group,
comm_group_size=self.args.dp_num,
msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel),
stage="step.finish_grad_sync",
)
)
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=2 * self._get_layernorm_params() // (args.pipeline_model_parallel),
stage="step._allreduce_layernorm_grads",
)
)
self.workload.append(
LogItem(
comm_type=CommType.all_reduce,
comm_group=CommGroup.tp_group,
comm_group_size=self.args.tensor_model_parallel_size,
msg_size=4,
stage="step.check_for_nan",
)
)


if __name__ == "__main__":
args = get_params()
model = MegatronModel(args)
workload_generator = MegatronWorkload(args, model)
workload = workload_generator()
filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv"
workload.dump(filename)
if args.enable_visual:
try:
from visualize.generate import visualize_output
base_name = filename.split(".")[0]
visualize_output(f"./results/mocked_workload/{base_name}_workload.csv",True)
except ImportError:
print("visualize_output is not available because required library is not found")
# WorkloadWriter.write_workload(workload, args, filename)