in optimum/fx/parallelization/parallel_layers/loss.py [0:0]
def sharded_cross_entropy_wrapper_fn(process_group: dist.ProcessGroup):
    """Build a `cross_entropy`-compatible callable backed by a vocab-sharded loss.

    The returned wrapper mimics `torch.nn.functional.cross_entropy`'s signature
    (including the deprecated `size_average`/`reduce` flags) but delegates the
    actual computation to `sharded_cross_entropy`, which reduces over logits
    sharded along the class dimension across `process_group`.

    Args:
        process_group: the distributed group over which the logits are sharded.

    Returns:
        A function with the `cross_entropy` signature that raises `ValueError`
        for the unsupported `weight`, `ignore_index` and `label_smoothing`
        options.
    """

    @wraps(sharded_cross_entropy)
    def wrapper(
        sharded_logits: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        size_average: Optional[bool] = None,
        ignore_index: int = -100,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ):
        # Features of the stock cross entropy that the parallel kernel does
        # not implement — reject them loudly instead of silently ignoring.
        unsupported = (
            weight is not None or ignore_index != -100 or label_smoothing != 0.0
        )
        if unsupported:
            raise ValueError(
                "Does not support weighted mode, index ignoring and label smoothing in current parallel cross entropy implementation."
            )

        loss: torch.Tensor = sharded_cross_entropy(sharded_logits, target, process_group)

        # Honor the deprecated flags exactly as torch does: when either is
        # given, they override `reduction` (unset flags default to True).
        if size_average is not None or reduce is not None:
            size_average = True if size_average is None else size_average
            reduce = True if reduce is None else reduce
            if not reduce:
                reduction = "none"
            else:
                reduction = "mean" if size_average else "sum"

        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        return loss

    return wrapper