in workload_generator/generate_megatron_workload.py
def with_pipeline_forward_backward(self):
    args = self.args
    if args.workload_only:
        rank = 0
    else:
        import torch
        rank = torch.distributed.get_rank()
    world_size = args.world_size
    pp_rank = self.get_pp_rank(rank, world_size, args.pipeline_model_parallel)
    pp_num_warmup_microbatches = min(
        args.pipeline_model_parallel - pp_rank - 1, args.num_microbatches
    )
    num_microbatches_remaining = args.num_microbatches - pp_num_warmup_microbatches
    temp = self.model.forward()  # result is only referenced by the commented-out line below
    # forward_comm = self._get_comm_op(temp)
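    # Warmup phase: mirrors Megatron-LM's non-interleaved 1F1B schedule, so each
    # stage issues min(pipeline_model_parallel - pp_rank - 1, num_microbatches)
    # forward passes before steady state. The point-to-point msg_size of
    # 2 * hidden_size * seq_length * micro_batch presumably models a 2-byte
    # (fp16/bf16) activation of shape [seq_length, micro_batch, hidden_size]; the
    # two tp_group broadcasts appear to model the per-microbatch batch broadcast.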
    for _ in range(pp_num_warmup_microbatches):
        if pp_rank != 0:
            # recv_prev
            self.workload.append(
                LogItem(
                    comm_type=CommType.irecv,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="forward_step",
                    additional="recv_prev",
                )
            )
        self.workload.append(
            LogItem(
                comm_type=CommType.broadcast,
                comm_group=CommGroup.tp_group,
                comm_group_size=self.args.tensor_model_parallel_size,
                msg_size=5 * 8,
                stage="forward_step",
                src=0,
            )
        )
        self.workload.append(
            LogItem(
                comm_type=CommType.broadcast,
                comm_group=CommGroup.tp_group,
                comm_group_size=self.args.tensor_model_parallel_size,
                msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
                stage="forward_step",
                src=0,
            )
        )
        # for item in forward_comm:
        self.workload.extend(self.model.forward())
        if pp_rank != args.pipeline_model_parallel - 1:
            # send_next
            self.workload.append(
                LogItem(
                    comm_type=CommType.isend,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="forward_step",
                    additional="send_next",
                )
            )
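    # Between warmup and steady state, every stage except the first prefetches
    # the input activation for its next forward pass (only if microbatches remain).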
    # recv prev
    if num_microbatches_remaining > 0 and pp_rank != 0:
        self.workload.append(
            LogItem(
                comm_type=CommType.irecv,
                comm_group=CommGroup.pp_group,
                comm_group_size=1,
                msg_size=2
                * (args.hidden_size * args.seq_length * args.micro_batch),
                stage="forward_step",
                additional="recv_prev",
            )
        )
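    # Steady-state (1F1B) phase: one forward followed by one backward per remaining
    # microbatch. Stages other than the last pair send_next (forward output) with
    # recv_next (incoming gradient); stages other than the first issue send_prev
    # (input gradient) and, except on the last iteration, also post recv_prev for
    # the next forward input.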
    for i in range(num_microbatches_remaining):
        last_iter = i == (num_microbatches_remaining - 1)
        self.workload.append(
            LogItem(
                comm_type=CommType.broadcast,
                comm_group=CommGroup.tp_group,
                comm_group_size=self.args.tensor_model_parallel_size,
                msg_size=5 * 8,
                stage="forward_step",
                src=0,
            )
        )
        self.workload.append(
            LogItem(
                comm_type=CommType.broadcast,
                comm_group=CommGroup.tp_group,
                comm_group_size=self.args.tensor_model_parallel_size,
                msg_size=8 * (args.world_size + args.seq_length * args.micro_batch),
                stage="forward_step",
                src=0,
            )
        )
        self.workload.extend(self.model.forward())
        if pp_rank != args.pipeline_model_parallel - 1:
            # recv next
            self.workload.append(
                LogItem(
                    comm_type=CommType.irecv,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="forward_step",
                    additional="recv_next",
                )
            )
            # send next
            self.workload.append(
                LogItem(
                    comm_type=CommType.isend,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="forward_step",
                    additional="send_next",
                )
            )
        self.workload.extend(self.model.backward())
        if pp_rank != 0:
            if last_iter:
                # send prev
                self.workload.append(
                    LogItem(
                        comm_type=CommType.isend,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="backward_step",
                        additional="send_prev",
                    )
                )
            else:
                # send prev recv prev
                self.workload.append(
                    LogItem(
                        comm_type=CommType.isend,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="backward_step",
                        additional="send_prev",
                    )
                )
                self.workload.append(
                    LogItem(
                        comm_type=CommType.irecv,
                        comm_group=CommGroup.pp_group,
                        comm_group_size=1,
                        msg_size=2
                        * (args.hidden_size * args.seq_length * args.micro_batch),
                        stage="backward_step",
                        additional="recv_prev",
                    )
                )
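    # Cooldown phase: drain one backward pass per warmup forward. Stages other
    # than the last receive the gradient from the next stage; stages other than
    # the first send their input gradient to the previous stage.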
    for _ in range(pp_num_warmup_microbatches):
        # recv next
        if pp_rank != args.pipeline_model_parallel - 1:
            self.workload.append(
                LogItem(
                    comm_type=CommType.irecv,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="backward_step",
                    additional="recv_next",
                )
            )
        self.workload.extend(self.model.backward())
        # send prev
        if pp_rank != 0:
            self.workload.append(
                LogItem(
                    comm_type=CommType.isend,
                    comm_group=CommGroup.pp_group,
                    comm_group_size=1,
                    msg_size=2
                    * (args.hidden_size * args.seq_length * args.micro_batch),
                    stage="backward_step",
                    additional="send_prev",
                )
            )