def train_step_pipeline_afab()

in picotron/pipeline_parallel/pipeline_parallel.py


def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
    """
    Implements All-Forward-All-Backward (AFAB) pipeline parallel training.
    First performs all forward passes, then all backward passes sequentially.
    
    Args:
        model: The pipeline parallel model
        data_loader: Iterator providing training batches
        tensor_shapes: Expected shapes of tensors for communication
        device: Device to run computations on
        dtype: Data type for tensors

    Returns:
        The microbatch-averaged loss for logging (non-zero only on the last pipeline stage, which computes the loss).
    """
    logging_loss: float = 0.0
    # Store tensors to recreate computation graph during backward pass
    input_tensors, output_tensors = [], []
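    # Gradient sync across data/context-parallel ranks is deferred to the last microbatch (see require_backward_grad_sync below)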
    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1

    for _ in range(data_loader.grad_acc_steps): # All forward passes
        # communication: receive the activation from the previous stage
        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
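        # (returns None on the first pipeline stage, which has no previous stage to receive from)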
        # or fetch from data loader
        batch = next(data_loader)
        batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
        # forward pass through this pipeline stage
        output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
        # communication: send the activation to the next stage
        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
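        # (no-op on the last pipeline stage, whose output is consumed by the loss computation below)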
        
        # calculate loss on the last stage
        if pgm.process_group_manager.pp_is_last_stage:
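            # F.cross_entropy expects class logits on dim 1, hence the (batch, seq, vocab) -> (batch, vocab, seq) transpose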
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            logging_loss += output_tensor.item() / data_loader.grad_acc_steps

        # Save input/output activations to reconstruct computation graph during backward pass
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    for ith_microbatch in range(data_loader.grad_acc_steps): # All backward passes
        if requires_grad_sync:
            is_last_iteration = (ith_microbatch == data_loader.grad_acc_steps - 1)
            model.require_backward_grad_sync = is_last_iteration
        # communication: receive the gradient from the next stage
        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
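        # (None on the last stage, where the backward pass starts directly from the loss)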
        # Retrieve saved input/output activations in FIFO order to match forward pass sequence
        input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
        # backward pass through this pipeline stage
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        # communication: send the gradient to the previous stage
        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
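        # (no-op on the first stage, which has no previous stage to send gradients to)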

    return logging_loss
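
A minimal sketch of how this training step might be driven from an outer loop. The optimizer, step count, and logging below are illustrative assumptions rather than picotron's actual training script; tensor_shapes is the shape of the activations exchanged between stages (typically micro-batch size x sequence length x hidden size).

# Hypothetical driver loop; everything except train_step_pipeline_afab and pgm is assumed for illustration.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
num_train_steps = 1000  # illustrative

for step in range(num_train_steps):
    optimizer.zero_grad()
    # Runs all forward passes, then all backward passes, over one batch of microbatches
    loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
    optimizer.step()
    # Only the last pipeline stage computes the loss, so only it logs a meaningful value
    if pgm.process_group_manager.pp_is_last_stage:
        print(f"step {step}: loss = {loss:.4f}")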