optimum/habana/distributed/contextparallel.py
from typing import Optional

import torch

from .parallel_state import (
get_sequence_parallel_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
)


class ContextParallelLossFunction(torch.autograd.Function):
"""
Gather losses across context parallel group.
This custom autograd function is designed to handle the distribution of loss computation
across multiple parallel contexts in a distributed training setup. It ensures that the loss
is gathered from all devices involved in the parallel context, allowing for consistent and
accurate computation of the overall loss.
The forward method gathers the loss from all ranks in the context parallel group, while the
backward method ensures that gradients are correctly synchronized across the different parallel
contexts.
"""
@staticmethod
def forward(ctx, loss):
ctx.seqlen = loss.size(0) * get_sequence_parallel_world_size()
# Create a tensor to gather all losses from context parallel group
loss_all = torch.empty(ctx.seqlen, dtype=loss.dtype, device=loss.device)
# Gather losses from all ranks in the group
torch.distributed.all_gather_into_tensor(loss_all, loss, group=get_sequence_parallel_group())
        return loss_all

    @staticmethod
def backward(ctx, grad_output):
step_seqlen = ctx.seqlen // get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
# Extract the relevant part of the gradient for this rank
grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
return grad_output_part, None
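

# Illustration only, not part of the original module: a minimal single-process sketch of
# the same gather-in-forward / slice-in-backward pattern. It runs against the default
# process group (gloo backend, world_size=1) so it does not require the Habana sequence
# parallel state; the helper name `_demo_gather_and_slice`, the gloo backend and the
# address/port values below are assumptions made for this sketch only.
def _demo_gather_and_slice():
    import os

    if not torch.distributed.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        torch.distributed.init_process_group("gloo", rank=0, world_size=1)
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    class _GatherLoss(torch.autograd.Function):
        @staticmethod
        def forward(ctx, loss):
            ctx.seqlen = loss.size(0) * world_size
            # List-based all_gather is used here because gloo supports it; the module
            # above uses all_gather_into_tensor on the sequence parallel group instead.
            chunks = [torch.empty_like(loss) for _ in range(world_size)]
            torch.distributed.all_gather(chunks, loss)
            return torch.cat(chunks)

        @staticmethod
        def backward(ctx, grad_output):
            step = ctx.seqlen // world_size
            # Each rank keeps only the gradient slice matching its local tokens.
            return grad_output[step * rank : step * (rank + 1)]

    local_loss = torch.rand(4, requires_grad=True)
    gathered = _GatherLoss.apply(local_loss)
    gathered.mean().backward()
    # With world_size=1 the gathered tensor equals the local one and every gradient
    # entry is 1 / seqlen.
    print(gathered.shape, local_loss.grad)

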
def fixed_cross_entropy(source, target, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs):
loss_all = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
    # Gather the per-token losses from all ranks in the context parallel group
loss_all = ContextParallelLossFunction.apply(loss_all)
if num_items_in_batch is None:
loss = torch.mean(loss_all)
else:
loss = torch.sum(loss_all) / num_items_in_batch
return loss
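

# Illustration only: the two reduction modes of `fixed_cross_entropy`, shown on a single
# device without the all-gather. With `num_items_in_batch`, the summed per-token losses
# are normalized by a caller-supplied global token count (as in gradient accumulation);
# otherwise the gathered losses are simply averaged. The tensor shapes and the token
# count of 6 below are made up for the example.
def _demo_reduction_modes():
    logits = torch.randn(8, 32)  # 8 tokens, vocabulary size 32
    labels = torch.randint(0, 32, (8,))
    labels[-2:] = -100  # ignored positions contribute a per-token loss of 0
    per_token = torch.nn.functional.cross_entropy(logits, labels, ignore_index=-100, reduction="none")
    mean_loss = torch.mean(per_token)
    normalized_loss = torch.sum(per_token) / 6  # e.g. 6 non-ignored tokens in the global batch
    print(per_token.shape, mean_loss.item(), normalized_loss.item())

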
def ForCausalLMContextParallelLoss(
    logits, labels, vocab_size: int, num_items_in_batch: Optional[int] = None, ignore_index: int = -100, **kwargs
):
    # Upcast to float to avoid potential precision issues when computing the loss
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
return loss
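

# Illustration only: the shift-and-flatten step of ForCausalLMContextParallelLoss on
# dummy tensors, without the distributed gather. The batch, sequence and vocabulary
# sizes below are assumptions made for the example; logits of shape (batch, seq, vocab)
# become (batch * (seq - 1), vocab) next-token predictions.
def _demo_shift_shapes():
    batch, seq_len, vocab_size = 2, 5, 32
    logits = torch.randn(batch, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch, seq_len))
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)  # (8, 32)
    shift_labels = labels[..., 1:].contiguous().view(-1)  # (8,)
    # In the function above, fixed_cross_entropy would then gather the per-token losses
    # across the context parallel group before reducing them.
    print(shift_logits.shape, shift_labels.shape)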