# step8_pipeline_parallel_1f1b/pipeline_parallel.py
def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
    # Each stage warms up with just enough forward passes to fill the pipeline ahead of it;
    # the remaining microbatches run in the 1F1B (one-forward-one-backward) steady state.
    num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.grad_acc_steps)
    num_microbatches_remaining = data_loader.grad_acc_steps - num_warmup_microbatches
    logging_loss, input_tensors, output_tensors = 0.0, [], []
    # Gradient synchronization across data-parallel ranks is only needed when DP > 1.
    requires_grad_sync = pgm.process_group_manager.dp_world_size > 1
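    # Worked example (illustration only, not from the original source): with pp_world_size=4
    # and grad_acc_steps=8, rank 0 warms up min(4-0-1, 8) = 3 microbatches, rank 1 warms up 2,
    # rank 2 warms up 1, and the last rank warms up 0 and runs pure 1F1B from the first microbatch.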
    def _forward_step(input_tensor):
        batch = next(data_loader)
        # The first stage reads token ids directly; later stages consume the previous stage's activations.
        batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
        # Only the last stage computes the loss; earlier stages return activations for the next stage.
        if pgm.process_group_manager.pp_is_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            nonlocal logging_loss
            logging_loss += output_tensor.item() / data_loader.grad_acc_steps
        return output_tensor
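    # Illustrative 1F1B schedule (assumed setup for illustration: pp_world_size=2, grad_acc_steps=4;
    # F = forward, B = backward, numbers are microbatch ids):
    #   stage 0: F1 F2 B1 F3 B2 F4 B3 B4      (1 warmup forward, alternate, 1 cooldown backward)
    #   stage 1:    F1 B1 F2 B2 F3 B3 F4 B8   -> actually F4 B4: 0 warmup forwards, pure 1F1B
    # i.e. stage 1: F1 B1 F2 B2 F3 B3 F4 B4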
    # === Warmup forward passes: fill the pipeline before any backward runs ===
    for _ in range(num_warmup_microbatches):
        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
        output_tensor = _forward_step(input_tensor)
        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
        # Stash the activations; their backward passes happen during the steady state / cooldown phases.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
    # Prefetch the input for the first steady-state forward pass.
    if num_microbatches_remaining > 0:
        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
    # With data parallelism, postpone the gradient all-reduce until the final backward pass
    # (same idea as DDP's no_sync): intermediate microbatches only accumulate gradients locally.
    if requires_grad_sync:
        model.require_backward_grad_sync = False
    # === 1F1B steady state: one forward for a new microbatch, one backward for the oldest in-flight microbatch ===
    for ith_microbatch in range(num_microbatches_remaining):
        is_last_iteration = (ith_microbatch == num_microbatches_remaining - 1)
        output_tensor = _forward_step(input_tensor)
        # Send this microbatch's activations downstream and receive the gradient for the oldest in-flight microbatch.
        output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        # Trigger gradient sync on the last microbatch, but only on the last rank
        # (the one with num_warmup_microbatches == 0), which finishes all its backward passes here.
        if num_warmup_microbatches == 0 and is_last_iteration:
            model.require_backward_grad_sync = True
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        if is_last_iteration:
            input_tensor = None
            pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
        else:
            # Send this microbatch's input gradient upstream and receive the next microbatch's activations.
            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)
    # === Cooldown backward passes: drain the microbatches stashed during warmup ===
    for ith_warmup_microbatch in range(num_warmup_microbatches):
        if requires_grad_sync:
            # Only the very last backward pass triggers the data-parallel gradient all-reduce.
            is_last_iteration = (ith_warmup_microbatch == num_warmup_microbatches - 1)
            model.require_backward_grad_sync = is_last_iteration
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
    return logging_loss
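
# Illustrative usage (a sketch, not part of the original file): how this train step might be
# driven from an outer loop. `optimizer`, `num_train_steps`, and the exact layout of
# `tensor_shapes` are assumptions for illustration only.
#
#   tensor_shapes = (micro_batch_size, seq_len, hidden_dim)  # shape of activations exchanged between stages
#   for step in range(num_train_steps):
#       optimizer.zero_grad()
#       loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
#       optimizer.step()
#       if pgm.process_group_manager.pp_is_last_stage:
#           print(f"step {step}: loss {loss:.4f}")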