in optimum/neuron/models/training/loss_utils.py [0:0]
def forward(ctx, vocab_parallel_logits, target, ignore_index=-100, reduction="mean", label_smoothing=0.0):
    # Maximum of the logits along the vocab dimension, reduced across the tensor-parallel group.
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    torch.distributed.all_reduce(
        logits_max,
        op=torch.distributed.ReduceOp.MAX,
        group=get_tensor_model_parallel_group(),
    )
    # Subtract the maximum value for numerical stability; this leaves the softmax unchanged.
    vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
    # Get the partition's vocab indices.
    get_vocab_range = EmbeddingUtility.range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_size()
    vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
    # Original implementation:
    # masked_target = target.clone() - vocab_start_index
    # masked_target[target_mask] = 0
    # New XLA-friendly implementation:
    is_not_ignore_index_mask = (target != ignore_index).to(vocab_parallel_logits.dtype)
    target_mask = (target >= vocab_start_index) & (target < vocab_end_index)
    masked_target = target.clone() - vocab_start_index
    masked_target = torch.mul(masked_target, target_mask.long())
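    # Boolean-mask assignment relies on data-dependent indexing, which XLA handles poorly, so the mask is
    # applied with a multiply instead: targets that live on another rank's vocab shard are clamped to local
    # index 0 here and their gathered logits are zeroed out again further down.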
    # Get predicted-logits = logits[target].
    # For simplicity, we convert logits to a 2-D tensor with size
    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
    logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device, dtype=torch.long)
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits_1d = predicted_logits_1d.clone().contiguous()
    predicted_logits = predicted_logits_1d.view_as(target)
    # Original implementation:
    # predicted_logits[target_mask] = 0.0
    # New XLA-friendly implementation:
    predicted_logits = torch.mul(predicted_logits, target_mask.float())
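    # At this point predicted_logits only holds the target logit for tokens whose label falls in this
    # rank's vocab shard; every other position is zero, so the SUM all-reduce below assembles the full
    # vector of target logits across the tensor-parallel group.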
    # All reduce is needed to get the chunks from other devices.
    torch.distributed.all_reduce(
        predicted_logits,
        op=torch.distributed.ReduceOp.SUM,
        group=get_tensor_model_parallel_group(),
    )
    # Sum of exponential of logits along vocab dimension across all devices.
    # Original implementation:
    # exp_logits = vocab_parallel_logits
    # torch.exp(vocab_parallel_logits, out=exp_logits)
    # New XLA-friendly implementation:
    exp_logits = torch.exp(vocab_parallel_logits)
    sum_exp_logits = exp_logits.sum(dim=-1)
    torch.distributed.all_reduce(
        sum_exp_logits,
        op=torch.distributed.ReduceOp.SUM,
        group=get_tensor_model_parallel_group(),
    )
    # Loss = log(sum(exp(logits))) - predicted-logit.
    loss = torch.log(sum_exp_logits) - predicted_logits
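    # This is the usual cross-entropy identity:
    #   -log softmax(x)[t] = -(x_t - log(sum_i exp(x_i))) = log(sum_i exp(x_i)) - x_t
    # The maximum subtracted above cancels out between the two terms.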
    # Store softmax, target-mask and masked-target for backward pass.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
    vocab_size = exp_logits.size(-1)
    if label_smoothing > 0:
        """
        We'd like to assign alpha / (K - 1) probability mass to every index that is not the ground truth:
        = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
        = (1 - alpha) * y_gt + (alpha / (K - 1)) * sum_{i != gt} y_i
        = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * sum_{i != gt} y_i
        = ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * sum_{i} y_i
        = (1 - (alpha * K) / (K - 1)) * y_gt + ((alpha * K) / (K - 1)) * sum_{i} y_i / K
        From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
        """
        assert 1.0 > label_smoothing > 0.0
        smoothing = label_smoothing * vocab_size / (vocab_size - 1)
        # Exp logits at this point are normalized probabilities, so we can just take the log to get log-probs.
        log_probs = torch.log(exp_logits)
        mean_log_probs = log_probs.mean(dim=-1)
        loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
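        # For example, with K = 4 and label_smoothing = 0.1: smoothing = 0.1 * 4 / 3 ≈ 0.133,
        # so the final loss is roughly 0.867 * cross_entropy - 0.133 * mean_log_probs.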
    # Zero out the loss for the ignored tokens.
    loss = loss * is_not_ignore_index_mask
    # Apply the reduction. To match the torch.nn.functional.cross_entropy API, the reduction
    # only takes the non-ignored tokens into account.
    if reduction == "sum":
        loss = loss.sum()
    elif reduction == "mean":
        num_non_ignored_tokens = is_not_ignore_index_mask.sum()
        loss = loss.sum() / num_non_ignored_tokens
    ctx.reduction = reduction
    ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d, is_not_ignore_index_mask)
    return loss
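
For intuition, here is a minimal, non-parallel sketch of the same computation on a single device. The helper name reference_cross_entropy and its signature are illustrative and not part of the file; it simply mirrors the steps above (max subtraction, log-sum-exp minus the target logit, label smoothing, ignore-index masking, reduction over non-ignored tokens) without any tensor-parallel communication, and with label_smoothing=0 it behaves like torch.nn.functional.cross_entropy.

import torch


def reference_cross_entropy(logits, target, ignore_index=-100, reduction="mean", label_smoothing=0.0):
    # logits: [*, vocab_size] on a single device; target: [*] with class indices or ignore_index.
    logits = logits - logits.max(dim=-1, keepdim=True).values  # numerical stability, cancels out below
    sum_exp_logits = torch.exp(logits).sum(dim=-1)
    # Clamp ignored targets to a valid index so gather stays in range; their loss is masked out later.
    safe_target = target.clamp(min=0)
    predicted_logits = logits.gather(-1, safe_target.unsqueeze(-1)).squeeze(-1)
    loss = torch.log(sum_exp_logits) - predicted_logits

    if label_smoothing > 0:
        vocab_size = logits.size(-1)
        smoothing = label_smoothing * vocab_size / (vocab_size - 1)
        log_probs = logits - torch.log(sum_exp_logits).unsqueeze(-1)
        loss = (1.0 - smoothing) * loss - smoothing * log_probs.mean(dim=-1)

    not_ignored = (target != ignore_index).to(loss.dtype)
    loss = loss * not_ignored
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.sum() / not_ignored.sum()
    return loss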